Source code for hyperion.torch.narchs.resnet2d_encoder

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import math
from jsonargparse import ArgumentParser, ActionParser

import torch
import torch.nn as nn

from ..layers import ActivationFactory as AF
from ..layers import NormLayer2dFactory as NLF
from ..layer_blocks import ResNet2dBasicBlock, ResNet2dBNBlock, DC2dEncBlock
from ..layer_blocks import SEResNet2dBasicBlock, SEResNet2dBNBlock
from ..layer_blocks import Res2Net2dBasicBlock, Res2Net2dBNBlock
from .net_arch import NetArch


[docs]class ResNet2dEncoder(NetArch):
[docs] def __init__( self, in_channels=1, in_conv_channels=64, in_kernel_size=3, in_stride=1, resb_type="basic", resb_repeats=[2, 2, 2, 2], resb_channels=[64, 128, 256, 512], resb_kernel_sizes=3, resb_strides=2, resb_dilations=1, resb_groups=1, head_channels=0, hid_act="relu6", head_act=None, dropout_rate=0, se_r=16, time_se=False, in_feats=None, res2net_width_factor=1, res2net_scale=4, use_norm=True, norm_layer=None, norm_before=True, ): super().__init__() self.resb_type = resb_type bargs = {} # block's extra arguments if resb_type == "basic": self._block = ResNet2dBasicBlock elif resb_type == "bn": self._block = ResNet2dBNBlock elif resb_type == "sebasic": self._block = SEResNet2dBasicBlock bargs["se_r"] = se_r elif resb_type == "sebn": self._block = SEResNet2dBNBlock bargs["se_r"] = se_r elif resb_type in ["res2basic", "seres2basic", "res2bn", "seres2bn"]: bargs["width_factor"] = res2net_width_factor bargs["scale"] = res2net_scale if resb_type in ["seres2basic", "seres2bn"]: bargs["se_r"] = se_r bargs["time_se"] = time_se if resb_type in ["res2basic", "seres2basic"]: self._block = Res2Net2dBasicBlock else: self._block = Res2Net2dBNBlock self.in_channels = in_channels self.in_conv_channels = in_conv_channels self.in_kernel_size = in_kernel_size self.in_stride = in_stride num_superblocks = len(resb_repeats) self.resb_repeats = resb_repeats self.resb_channels = self._standarize_resblocks_param( resb_channels, num_superblocks, "resb_channels" ) self.resb_kernel_sizes = self._standarize_resblocks_param( resb_kernel_sizes, num_superblocks, "resb_kernel_sizes" ) self.resb_strides = self._standarize_resblocks_param( resb_strides, num_superblocks, "resb_strides" ) self.resb_dilations = self._standarize_resblocks_param( resb_dilations, num_superblocks, "resb_dilations" ) self.resb_groups = resb_groups self.head_channels = head_channels self.hid_act = hid_act self.head_act = head_act self.dropout_rate = dropout_rate self.use_norm = use_norm self.norm_before = norm_before self.se_r = se_r self.time_se = time_se self.in_feats = in_feats self.res2net_width_factor = res2net_width_factor self.res2net_scale = res2net_scale self.norm_layer = norm_layer norm_groups = None if norm_layer == "group-norm": norm_groups = min(np.min(resb_channels) // 2, 32) norm_groups = max(norm_groups, resb_groups) self._norm_layer = NLF.create(norm_layer, norm_groups) # stem block self.in_block = DC2dEncBlock( in_channels, in_conv_channels, in_kernel_size, stride=in_stride, activation=hid_act, dropout_rate=dropout_rate, use_norm=use_norm, norm_layer=self._norm_layer, norm_before=norm_before, ) self._context = self.in_block.context self._downsample_factor = self.in_block.stride cur_in_channels = in_conv_channels # middle blocks self.blocks = nn.ModuleList([]) for i in range(num_superblocks): repeats_i = self.resb_repeats[i] channels_i = self.resb_channels[i] stride_i = self.resb_strides[i] kernel_size_i = self.resb_kernel_sizes[i] dilation_i = self.resb_dilations[i] # if there is downsampling the dilation of the first block # is set to 1 dilation_i1 = dilation_i if stride_i == 1 else 1 if time_se: num_feats_i = int(self.in_feats / (self._downsample_factor * stride_i)) bargs["num_feats"] = num_feats_i block_i = self._block( cur_in_channels, channels_i, kernel_size_i, stride=stride_i, dilation=dilation_i1, groups=self.resb_groups, activation=hid_act, dropout_rate=dropout_rate, use_norm=use_norm, norm_layer=self._norm_layer, norm_before=norm_before, **bargs ) self.blocks.append(block_i) self._context += block_i.context * self._downsample_factor self._downsample_factor *= block_i.downsample_factor for j in range(repeats_i - 1): block_i = self._block( channels_i, channels_i, kernel_size_i, stride=1, dilation=dilation_i, groups=self.resb_groups, activation=hid_act, dropout_rate=dropout_rate, use_norm=use_norm, norm_layer=self._norm_layer, norm_before=norm_before, **bargs ) self.blocks.append(block_i) self._context += block_i.context * self._downsample_factor cur_in_channels = channels_i # head feature block if self.head_channels > 0: self.head_block = DC2dEncBlock( cur_in_channels, head_channels, kernel_size=1, stride=1, activation=head_act, use_norm=False, norm_before=norm_before, ) self._init_weights(hid_act)
def _init_weights(self, hid_act): if isinstance(hid_act, str): act_name = hid_act if isinstance(hid_act, dict): act_name = hid_act["name"] if act_name in ["relu6", "swish"]: act_name = "relu" for m in self.modules(): if isinstance(m, nn.Conv2d): try: nn.init.kaiming_normal_( m.weight, mode="fan_out", nonlinearity=act_name ) except: nn.init.kaiming_normal_( m.weight, mode="fan_out", nonlinearity="relu" ) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @staticmethod def _standarize_resblocks_param(p, num_blocks, p_name): if isinstance(p, int): p = [p] * num_blocks elif isinstance(p, list): if len(p) == 1: p = p * num_blocks assert len(p) == num_blocks, "len(%s)(%d)!=%d" % ( p_name, len(p), num_blocks, ) else: raise TypeError("wrong type for param {}={}".format(p_name, p)) return p def _compute_out_size(self, in_size): out_size = int((in_size - 1) // self.in_stride + 1) for stride in self.resb_strides: out_size = int((out_size - 1) // stride + 1) return out_size
[docs] def in_context(self): return (self._context, self._context)
[docs] def in_shape(self): return (None, self.in_channels, None, None)
[docs] def out_shape(self, in_shape=None): out_channels = ( self.head_channels if self.head_channels > 0 else self.resb_channels[-1] ) if in_shape is None: return (None, out_channels, None, None) assert len(in_shape) == 4 if in_shape[2] is None: H = None else: H = self._compute_out_size(in_shape[2]) if in_shape[3] is None: W = None else: W = self._compute_out_size(in_shape[3]) return (in_shape[0], out_chanels, H, W)
[docs] def forward(self, x): x = self.in_block(x) for idx, block in enumerate(self.blocks): x = block(x) if self.head_channels > 0: x = self.head_block(x) return x
[docs] def get_config(self): head_act = self.head_act hid_act = self.hid_act config = { "in_channels": self.in_channels, "in_conv_channels": self.in_conv_channels, "in_kernel_size": self.in_kernel_size, "in_stride": self.in_stride, "resb_type": self.resb_type, "resb_repeats": self.resb_repeats, "resb_channels": self.resb_channels, "resb_kernel_sizes": self.resb_kernel_sizes, "resb_strides": self.resb_strides, "resb_dilations": self.resb_dilations, "resb_groups": self.resb_groups, "head_channels": self.head_channels, "dropout_rate": self.dropout_rate, "hid_act": hid_act, "head_act": head_act, "se_r": self.se_r, "time_se": self.time_se, "res2net_width_factor": self.res2net_width_factor, "res2net_scale": self.res2net_scale, "use_norm": self.use_norm, "norm_layer": self.norm_layer, "norm_before": self.norm_before, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items()))
[docs] @staticmethod def filter_args(**kwargs): if "wo_norm" in kwargs: kwargs["use_norm"] = not kwargs["wo_norm"] del kwargs["wo_norm"] if "norm_after" in kwargs: kwargs["norm_before"] = not kwargs["norm_after"] del kwargs["norm_after"] valid_args = ( "in_channels", "in_conv_channels", "in_kernel_size", "in_stride", "resb_type", "resb_repeats", "resb_channels", "resb_kernel_sizes", "resb_strides", "resb_dilations", "resb_groups", "head_channels", "se_r", "time_se", "res2net_width_factor", "res2net_scale", "hid_act", "had_act", "dropout_rate", "use_norm", "norm_layer", "norm_before", ) args = dict((k, kwargs[k]) for k in valid_args if k in kwargs) return args
[docs] @staticmethod def add_class_args(parser, prefix=None, skip=set()): if prefix is not None: outer_parser = parser parser = ArgumentParser(prog="") parser.add_argument( "--in-channels", type=int, default=1, help=("input channel dimension") ) parser.add_argument( "--in-conv-channels", default=128, type=int, help=("number of output channels in input convolution"), ) parser.add_argument( "--in-kernel-size", default=3, type=int, help=("kernel size of input convolution"), ) parser.add_argument( "--in-stride", default=1, type=int, help=("stride of input convolution") ) parser.add_argument( "--resb-type", default="basic", choices=[ "basic", "bn", "sebasic", "sebn", "res2basic", "res2bn", "seres2basic", "sreres2bn", ], help=("residual blocks type"), ) parser.add_argument( "--resb-repeats", default=[1, 1, 1], type=int, nargs="+", help=("resb-blocks repeats in each encoder stage"), ) parser.add_argument( "--resb-channels", default=[128, 64, 32], type=int, nargs="+", help=("resb-blocks channels for each stage"), ) parser.add_argument( "--resb-kernel-sizes", default=3, nargs="+", type=int, help=("resb-blocks kernels for each encoder stage"), ) parser.add_argument( "--resb-strides", default=2, nargs="+", type=int, help=("resb-blocks strides for each encoder stage"), ) parser.add_argument( "--resb-dilations", default=[1], nargs="+", type=int, help=("resb-blocks dilations for each encoder stage"), ) parser.add_argument( "--resb-groups", default=1, type=int, help=("resb-blocks groups in convolutions"), ) if "head_channels" not in skip: parser.add_argument( "--head-channels", default=0, type=int, help=("channels in the last conv block of encoder"), ) try: parser.add_argument("--hid-act", default="relu6", help="hidden activation") except: pass parser.add_argument( "--head-act", default=None, help="activation in encoder head" ) try: parser.add_argument( "--dropout-rate", default=0, type=float, help="dropout probability" ) except: pass try: parser.add_argument( "--norm-layer", default=None, choices=[ "batch-norm", "group-norm", "instance-norm", "instance-norm-affine", "layer-norm", ], help="type of normalization layer", ) except: pass parser.add_argument( "--wo-norm", default=False, action="store_true", help="without batch normalization", ) parser.add_argument( "--norm-after", default=False, action="store_true", help="batch normalizaton after activation", ) parser.add_argument( "--se-r", default=16, type=int, help=("squeeze-excitation compression ratio"), ) parser.add_argument( "--time-se", default=False, action="store_true", help=("squeeze-excitation pooling is done in time dimension only"), ) parser.add_argument( "--res2net-width-factor", default=1, type=float, help=( "scaling factor for channels in middle layer " "of res2net bottleneck blocks" ), ) parser.add_argument( "--res2net-scale", default=1, type=float, help=("res2net scaling parameter "), ) if prefix is not None: outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
# help='ResNet2d encoder options') add_argparse_args = add_class_args