Source code for hyperion.torch.narchs.tdnn_factory

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

from jsonargparse import ArgumentParser, ActionParser

from .tdnn import TDNNV1
from .etdnn import ETDNNV1
from .resetdnn import ResETDNNV1


class TDNNFactory(object):
    """Factory that builds TDNN-family networks and declares their CLI options.

    The supported ``tdnn_type`` values are ``"tdnn"`` (TDNNV1), ``"etdnn"``
    (ETDNNV1) and ``"resetdnn"`` (ResETDNNV1).
    """

    @staticmethod
    def create(
        tdnn_type,
        num_enc_blocks,
        in_feats,
        enc_hid_units,
        enc_expand_units=None,
        kernel_size=3,
        dilation=1,
        dilation_factor=1,
        hid_act={"name": "relu6", "inplace": True},
        out_units=0,
        out_act=None,
        dropout_rate=0,
        norm_layer=None,
        use_norm=True,
        norm_before=True,
        in_norm=True,
    ):
        """Instantiate the network selected by ``tdnn_type``.

        Args:
          tdnn_type: one of ``"tdnn"``, ``"etdnn"`` or ``"resetdnn"``.
          num_enc_blocks: number of encoder layer blocks.
          in_feats: input feature dimension.
          enc_hid_units: hidden units of the encoder layers, an int or a
            list of ints.
          enc_expand_units: dimension of the last layer. For ``"tdnn"`` and
            ``"etdnn"`` with an int ``enc_hid_units`` it is appended as the
            size of the last block; for ``"resetdnn"`` it is passed as a
            separate argument and defaults to ``enc_hid_units``.
          kernel_size: kernel size(s) of the encoder conv1d layers.
          dilation: dilation(s) of the encoder conv1d layers.
          dilation_factor: dilation increment wrt the previous conv1d layer.
          hid_act: hidden activation specification. A dict is shallow-copied
            before being forwarded so that the shared default object can
            never be modified by the network constructors.
          out_units: forwarded to the network constructor as ``out_units``.
          out_act: forwarded to the network constructor as ``out_act``.
          dropout_rate: dropout rate.
          norm_layer: type of normalization layer.
          use_norm: forwarded to the network constructor as ``use_norm``.
          norm_before: forwarded to the network constructor as
            ``norm_before``.
          in_norm: forwarded to the network constructor as ``in_norm``.

        Returns:
          The TDNNV1, ETDNNV1 or ResETDNNV1 object that was built.

        Raises:
          ValueError: if ``tdnn_type`` is not a supported network type.
        """
        # The default dict is created once at definition time and shared by
        # every call; pass a private copy downstream.
        if isinstance(hid_act, dict):
            hid_act = dict(hid_act)

        # For the non-residual variants the expansion layer is expressed as a
        # per-block list whose last entry is the expanded dimension.
        if (
            enc_expand_units is not None
            and isinstance(enc_hid_units, int)
            and tdnn_type != "resetdnn"
        ):
            enc_hid_units = (num_enc_blocks - 1) * [enc_hid_units] + [
                enc_expand_units
            ]

        # Keyword arguments shared by the three network constructors.
        common_kwargs = dict(
            out_units=out_units,
            kernel_size=kernel_size,
            dilation=dilation,
            dilation_factor=dilation_factor,
            hid_act=hid_act,
            out_act=out_act,
            dropout_rate=dropout_rate,
            norm_layer=norm_layer,
            use_norm=use_norm,
            norm_before=norm_before,
            in_norm=in_norm,
        )

        if tdnn_type == "tdnn":
            return TDNNV1(num_enc_blocks, in_feats, enc_hid_units, **common_kwargs)

        if tdnn_type == "etdnn":
            return ETDNNV1(num_enc_blocks, in_feats, enc_hid_units, **common_kwargs)

        if tdnn_type == "resetdnn":
            if enc_expand_units is None:
                enc_expand_units = enc_hid_units
            return ResETDNNV1(
                num_enc_blocks,
                in_feats,
                enc_hid_units,
                enc_expand_units,
                **common_kwargs
            )

        # ValueError is a subclass of Exception, so existing handlers that
        # catch Exception keep working.
        raise ValueError("%s is not valid TDNN network" % (tdnn_type))

    @staticmethod
    def filter_args(**kwargs):
        """Select and normalize the keyword arguments accepted by ``create``.

        The negative command line flags are converted to their positive
        counterparts (``wo_norm`` -> ``use_norm``, ``norm_after`` ->
        ``norm_before``), unknown keys are dropped, and single-element lists
        given for ``enc_hid_units``, ``kernel_size`` or ``dilation`` are
        unwrapped to the element itself.

        Args:
          **kwargs: arbitrary keyword arguments, e.g. parsed CLI options.

        Returns:
          Dictionary containing only the recognized arguments.
        """
        if "wo_norm" in kwargs:
            kwargs["use_norm"] = not kwargs.pop("wo_norm")

        if "norm_after" in kwargs:
            kwargs["norm_before"] = not kwargs.pop("norm_after")

        valid_args = (
            "tdnn_type",
            "num_enc_blocks",
            "enc_hid_units",
            "enc_expand_units",
            "kernel_size",
            "dilation",
            "dilation_factor",
            "in_norm",
            "hid_act",
            "norm_layer",
            "use_norm",
            "norm_before",
            "in_feats",
            "dropout_rate",
        )
        args = {k: kwargs[k] for k in valid_args if k in kwargs}

        # Options declared with nargs="+" arrive as lists; a single value is
        # unwrapped to a scalar.
        for name in ("enc_hid_units", "kernel_size", "dilation"):
            value = args.get(name)
            if isinstance(value, list) and len(value) == 1:
                args[name] = value[0]

        return args

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Register the TDNN options in an argument parser.

        Args:
          parser: parser receiving the options. When ``prefix`` is given the
            options are added to a new nested parser instead, which is then
            attached to ``parser`` under ``--<prefix>``.
          prefix: optional name of the nested option group.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        def add_if_possible(*args, **kwargs):
            # Best effort: if the parser rejects the option (presumably
            # because it was already registered by the caller -- TODO
            # confirm), it is skipped instead of failing. Unlike a bare
            # except this lets KeyboardInterrupt and SystemExit propagate.
            try:
                parser.add_argument(*args, **kwargs)
            except Exception:
                pass

        parser.add_argument(
            "--tdnn-type",
            type=str.lower,
            default="resetdnn",
            choices=["tdnn", "etdnn", "resetdnn"],
            help=("TDNN type: TDNN, ETDNN, ResETDNN"),
        )
        parser.add_argument(
            "--num-enc-blocks",
            default=9,
            type=int,
            help=("number of encoder layer blocks"),
        )
        parser.add_argument(
            "--enc-hid-units",
            nargs="+",
            default=512,
            type=int,
            help=("number of hidden units of the encoder layers"),
        )
        parser.add_argument(
            "--enc-expand-units",
            default=None,
            type=int,
            help=("dimension of last layer of ResETDNN"),
        )
        parser.add_argument(
            "--kernel-size",
            nargs="+",
            default=3,
            type=int,
            help=("kernel sizes of encoder conv1d"),
        )
        parser.add_argument(
            "--dilation",
            nargs="+",
            default=1,
            type=int,
            help=("dilations of encoder conv1d"),
        )
        parser.add_argument(
            "--dilation-factor",
            default=1,
            type=int,
            help=("dilation increment wrt previous conv1d layer"),
        )

        add_if_possible("--hid-act", default="relu6", help="hidden activation")
        add_if_possible(
            "--norm-layer",
            default=None,
            choices=[
                "batch-norm",
                "group-norm",
                "instance-norm",
                "instance-norm-affine",
                "layer-norm",
            ],
            help="type of normalization layer",
        )

        parser.add_argument(
            "--in-norm",
            default=False,
            action="store_true",
            help="batch normalization at the input",
        )

        add_if_possible(
            "--wo-norm",
            default=False,
            action="store_true",
            help="without batch normalization",
        )
        add_if_possible(
            "--norm-after",
            default=False,
            action="store_true",
            help="batch normalizaton after activation",
        )
        add_if_possible("--dropout-rate", default=0, type=float, help="dropout")
        add_if_possible(
            "--in-feats",
            default=None,
            type=int,
            help=(
                "input feature dimension, "
                "if None it will try to infer from encoder network"
            ),
        )

        if prefix is not None:
            outer_parser.add_argument(
                "--" + prefix, action=ActionParser(parser=parser)
            )

    # Alias of add_class_args under a second name.
    add_argparse_args = add_class_args