# Source code for hyperion.torch.narchs.resnet_factory

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

from jsonargparse import ArgumentParser, ActionParser

from .resnet import *

# Registry mapping an architecture name to the network class that implements
# it. ``ResNetFactory.create`` looks the requested ``resnet_type`` up here, and
# ``ResNetFactory.add_class_args`` offers the keys as the valid choices of the
# ``--resnet-type`` option. The classes are presumably the ones brought into
# scope by the wildcard import from ``.resnet`` above.
resnet_dict = {
    # ResNet, ResNeXt and wide ResNet families.
    "resnet18": ResNet18,
    "resnet34": ResNet34,
    "resnet50": ResNet50,
    "resnet101": ResNet101,
    "resnet152": ResNet152,
    "resnext50_32x4d": ResNext50_32x4d,
    "resnext101_32x8d": ResNext101_32x8d,
    "wideresnet50": WideResNet50,
    "wideresnet101": WideResNet101,
    # "l"-prefixed variants (the meaning of the prefix is defined in .resnet).
    "lresnet18": LResNet18,
    "lresnet34": LResNet34,
    "lresnet50": LResNet50,
    "lresnext50_4x4d": LResNext50_4x4d,
    # "se"-prefixed variants (presumably squeeze-excitation, cf. --se-r).
    "seresnet18": SEResNet18,
    "seresnet34": SEResNet34,
    "seresnet50": SEResNet50,
    "seresnet101": SEResNet101,
    "seresnet152": SEResNet152,
    "seresnext50_32x4d": SEResNext50_32x4d,
    "seresnext101_32x8d": SEResNext101_32x8d,
    "sewideresnet50": SEWideResNet50,
    "sewideresnet101": SEWideResNet101,
    "selresnet18": SELResNet18,
    "selresnet34": SELResNet34,
    "selresnet50": SELResNet50,
    "selresnext50_4x4d": SELResNext50_4x4d,
    # "tse"-prefixed variants (NOTE(review): prefix semantics not visible here).
    "tseresnet18": TSEResNet18,
    "tseresnet34": TSEResNet34,
    "tseresnet50": TSEResNet50,
    "tseresnet101": TSEResNet101,
    "tseresnet152": TSEResNet152,
    "tseresnext50_32x4d": TSEResNext50_32x4d,
    "tseresnext101_32x8d": TSEResNext101_32x8d,
    "tsewideresnet50": TSEWideResNet50,
    "tsewideresnet101": TSEWideResNet101,
    "tselresnet18": TSELResNet18,
    "tselresnet34": TSELResNet34,
    "tselresnet50": TSELResNet50,
    "tselresnext50_4x4d": TSELResNext50_4x4d,
    # Res2Net counterparts (cf. --res2net-scale / --res2net-width-factor).
    "res2net18": Res2Net18,
    "res2net34": Res2Net34,
    "res2net50": Res2Net50,
    "res2net101": Res2Net101,
    "res2net152": Res2Net152,
    "res2next50_32x4d": Res2Next50_32x4d,
    "res2next101_32x8d": Res2Next101_32x8d,
    "wideres2net50": WideRes2Net50,
    "wideres2net101": WideRes2Net101,
    "lres2net50": LRes2Net50,
    "lres2next50_4x4d": LRes2Next50_4x4d,
    # "se"-prefixed Res2Net variants.
    "seres2net18": SERes2Net18,
    "seres2net34": SERes2Net34,
    "seres2net50": SERes2Net50,
    "seres2net101": SERes2Net101,
    "seres2net152": SERes2Net152,
    "seres2next50_32x4d": SERes2Next50_32x4d,
    "seres2next101_32x8d": SERes2Next101_32x8d,
    "sewideres2net50": SEWideRes2Net50,
    "sewideres2net101": SEWideRes2Net101,
    "selres2net50": SELRes2Net50,
    "selres2next50_4x4d": SELRes2Next50_4x4d,
    # "tse"-prefixed Res2Net variants.
    "tseres2net18": TSERes2Net18,
    "tseres2net34": TSERes2Net34,
    "tseres2net50": TSERes2Net50,
    "tseres2net101": TSERes2Net101,
    "tseres2net152": TSERes2Net152,
    "tseres2next50_32x4d": TSERes2Next50_32x4d,
    "tseres2next101_32x8d": TSERes2Next101_32x8d,
    "tsewideres2net50": TSEWideRes2Net50,
    "tsewideres2net101": TSEWideRes2Net101,
    "tselres2net50": TSELRes2Net50,
    "tselres2next50_4x4d": TSELRes2Next50_4x4d,
    # Additional "l"-prefixed ResNet34 variant registered under its own name.
    "lresnet34_345": LResNet34_345,
}


class ResNetFactory(object):
    """Factory that builds the networks registered in ``resnet_dict``.

    All members are static: the class is only a namespace grouping network
    construction (``create``), keyword filtering (``filter_args``) and
    command-line option registration (``add_class_args``).
    """

    @staticmethod
    def create(
        resnet_type,
        in_channels,
        conv_channels=64,
        base_channels=64,
        out_units=0,
        hid_act={"name": "relu6", "inplace": True},  # kept for API compatibility; copied below
        out_act=None,
        in_kernel_size=7,
        in_stride=2,
        zero_init_residual=False,
        groups=1,
        replace_stride_with_dilation=None,
        dropout_rate=0,
        norm_layer=None,
        norm_before=True,
        do_maxpool=True,
        in_norm=True,
        se_r=16,
        in_feats=None,
        res2net_scale=4,
        res2net_width_factor=1,
    ):
        """Instantiate the network class registered under ``resnet_type``.

        Every argument other than ``resnet_type`` is forwarded unchanged to the
        selected class; their exact semantics are defined by that class. The
        descriptions below mirror the help texts of ``add_class_args``.

        Args:
            resnet_type: key of ``resnet_dict`` selecting the architecture.
            in_channels: number of input channels.
            conv_channels: number of output channels in the input convolution.
            base_channels: base channels of the first ResNet block.
            out_units: forwarded to the network class.
            hid_act: hidden activation specification. A dict value is
                shallow-copied before being forwarded so the shared default
                dict can never be modified by the network class.
            out_act: forwarded to the network class.
            in_kernel_size: kernel size of the first convolution.
            in_stride: stride of the first convolution.
            zero_init_residual: zero-initialize the last BN in each residual
                branch.
            groups: number of groups in the residual block convolutions.
            replace_stride_with_dilation: forwarded to the network class.
            dropout_rate: dropout.
            norm_layer: type of normalization layer.
            norm_before: forwarded to the network class; ``filter_args``
                derives it as the negation of ``norm_after``.
            do_maxpool: forwarded to the network class; ``filter_args``
                derives it as the negation of ``no_maxpool``.
            in_norm: batch normalization at the input.
            se_r: squeeze ratio in squeeze-excitation blocks.
            in_feats: forwarded to the network class.
            res2net_scale: scale parameter for res2net.
            res2net_width_factor: multiplicative factor for the internal width
                of res2net.

        Returns:
            The instance created by the selected network class.

        Raises:
            ValueError: if ``resnet_type`` is not a key of ``resnet_dict``.
        """
        try:
            resnet_class = resnet_dict[resnet_type]
        except (KeyError, TypeError) as err:
            # TypeError covers unhashable values such as a list.
            raise ValueError(
                "%s is not valid ResNet network" % (resnet_type,)
            ) from err

        if isinstance(hid_act, dict):
            # Never hand the signature's shared default dict to the network.
            hid_act = dict(hid_act)

        return resnet_class(
            in_channels,
            conv_channels=conv_channels,
            base_channels=base_channels,
            out_units=out_units,
            hid_act=hid_act,
            out_act=out_act,
            in_kernel_size=in_kernel_size,
            in_stride=in_stride,
            zero_init_residual=zero_init_residual,
            groups=groups,
            replace_stride_with_dilation=replace_stride_with_dilation,
            dropout_rate=dropout_rate,
            norm_layer=norm_layer,
            norm_before=norm_before,
            do_maxpool=do_maxpool,
            in_norm=in_norm,
            se_r=se_r,
            in_feats=in_feats,
            res2net_scale=res2net_scale,
            res2net_width_factor=res2net_width_factor,
        )

    @staticmethod
    def filter_args(**kwargs):
        """Keep only the keyword arguments that ``create`` understands.

        The negated command-line flags are translated first: ``norm_after``
        becomes ``norm_before = not norm_after`` and ``no_maxpool`` becomes
        ``do_maxpool = not no_maxpool``; the original keys are removed.

        Args:
            **kwargs: arbitrary keyword arguments, e.g. parsed options.

        Returns:
            A new dict with the recognised keys present in ``kwargs``.
            NOTE(review): ``in_feats`` is accepted by ``create`` but is not in
            the list below, so it is dropped here - confirm this is intended.
        """
        if "norm_after" in kwargs:
            kwargs["norm_before"] = not kwargs.pop("norm_after")

        if "no_maxpool" in kwargs:
            kwargs["do_maxpool"] = not kwargs.pop("no_maxpool")

        valid_args = (
            "resnet_type",
            "in_channels",
            "conv_channels",
            "base_channels",
            "out_units",
            "hid_act",
            "out_act",
            "in_kernel_size",
            "in_stride",
            "zero_init_residual",
            "groups",
            "replace_stride_with_dilation",
            "dropout_rate",
            "in_norm",
            "norm_layer",
            "norm_before",
            "do_maxpool",
            "se_r",
            "res2net_scale",
            "res2net_width_factor",
        )
        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Register the ResNet command-line options.

        Args:
            parser: parser that receives the options.
            prefix: when not None, the options are added to a new
                ``ArgumentParser`` which is then attached to ``parser`` as the
                single option ``--<prefix>`` through ``ActionParser``.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        def add_optional(*args, **kwargs):
            # Best-effort by design: a rejected registration is skipped,
            # presumably because the option already exists on a shared parser.
            # ``Exception`` (not a bare except) keeps KeyboardInterrupt and
            # SystemExit propagating.
            try:
                parser.add_argument(*args, **kwargs)
            except Exception:
                pass

        resnet_types = resnet_dict.keys()

        parser.add_argument(
            "--resnet-type",
            type=str.lower,
            default="lresnet34",
            choices=resnet_types,
            help=("ResNet type"),
        )

        parser.add_argument(
            "--in-channels", default=1, type=int, help=("number of input channels")
        )

        parser.add_argument(
            "--conv-channels",
            default=64,
            type=int,
            help=("number of output channels in input convolution "),
        )

        parser.add_argument(
            "--base-channels",
            default=64,
            type=int,
            help=("base channels of first ResNet block"),
        )

        parser.add_argument(
            "--in-kernel-size",
            default=7,
            type=int,
            help=("kernel size of first convolution"),
        )

        parser.add_argument(
            "--in-stride", default=2, type=int, help=("stride of first convolution")
        )

        parser.add_argument(
            "--groups",
            default=1,
            type=int,
            help=("number of groups in residual blocks convolutions"),
        )

        add_optional(
            "--norm-layer",
            default=None,
            choices=[
                "batch-norm",
                "group-norm",
                "instance-norm",
                "instance-norm-affine",
                "layer-norm",
            ],
            help="type of normalization layer",
        )

        parser.add_argument(
            "--in-norm",
            default=False,
            action="store_true",
            help="batch normalization at the input",
        )

        parser.add_argument(
            "--no-maxpool",
            default=False,
            action="store_true",
            help="don't do max pooling after first convolution",
        )

        parser.add_argument(
            "--zero-init-residual",
            default=False,
            action="store_true",
            help="Zero-initialize the last BN in each residual branch",
        )

        # Option currently disabled in the original source:
        # parser.add_argument('--replace-stride-with-dilation', default=None, nargs='+', type=bool,
        #  help='replaces strides with dilations to increase context without downsampling')

        parser.add_argument(
            "--se-r",
            default=16,
            type=int,
            help=("squeeze ratio in squeeze-excitation blocks"),
        )

        parser.add_argument(
            "--res2net-scale", default=4, type=int, help=("scale parameter for res2net")
        )

        parser.add_argument(
            "--res2net-width-factor",
            default=1,
            type=float,
            help=("multiplicative factor for the internal width of res2net"),
        )

        add_optional("--hid-act", default="relu6", help="hidden activation")

        add_optional(
            "--norm-after",
            default=False,
            action="store_true",
            help="batch normalizaton after activation",
        )

        add_optional("--dropout-rate", default=0, type=float, help="dropout")

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
            # help='ResNet options')

    # Alias exposing the same function under a second name.
    add_argparse_args = add_class_args