Source code for hyperion.torch.narchs.resnet1d_encoder

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

from jsonargparse import ArgumentParser, ActionParser
import math

import numpy as np

import torch
import torch.nn as nn

from ..layers import ActivationFactory as AF
from ..layers import NormLayer1dFactory as NLF
from ..layer_blocks import (
    ResNet1dBasicBlock,
    ResNet1dBNBlock,
    DC1dEncBlock,
    ResNet1dEndpoint,
    SEResNet1dBasicBlock,
    SEResNet1dBNBlock,
    Res2Net1dBasicBlock,
    Res2Net1dBNBlock,
)
from .net_arch import NetArch


class ResNet1dEncoder(NetArch):
    """1d ResNet encoder with optional multilayer feature aggregation (MFA).

    The network is made of an input convolutional block, a sequence of
    "superblocks" (each one a stack of residual blocks sharing channels,
    kernel size and dilation), optional endpoint blocks that aggregate the
    outputs of several superblocks, and an optional 1x1 convolutional head.

    Args:
      in_feats: input feature dimension.
      in_conv_channels: number of output channels in the input convolution.
      in_kernel_size: kernel size of the input convolution.
      in_stride: stride of the input convolution.
      resb_type: residual block type, one of "basic", "bn", "sebasic",
        "sebn", "res2basic", "res2bn", "seres2basic", "seres2bn".
      resb_repeats: number of residual blocks in each superblock; its length
        defines the number of superblocks.
      resb_channels: channels of each superblock (int or list).
      resb_kernel_sizes: kernel size of each superblock (int or list).
      resb_strides: stride of the first block of each superblock
        (int or list).
      resb_dilations: dilation of each superblock (int or list).
      resb_groups: groups in the convolutions of the residual blocks.
      head_channels: channels of the last conv block of the encoder,
        if 0 there is no head block.
      hid_act: hidden activation (name or config dict with a "name" key).
      head_act: activation in the encoder head.
      dropout_rate: dropout probability.
      drop_connect_rate: layer drop probability, it grows linearly with
        the block index up to this value in the last block.
      se_r: squeeze-excitation compression ratio.
      res2net_width_factor: scaling factor for channels in the middle layer
        of res2net bottleneck blocks.
      res2net_scale: res2net scaling parameter.
      multilayer: if True, uses multilayer feature aggregation (MFA).
      multilayer_concat: if True, MFA concatenates the endpoints, otherwise
        it averages them.
      endpoint_channels: number of endpoint channels when using MFA, if None
        it is taken from ``resb_channels[endpoint_scale_layer]``.
      endpoint_layers: 1-based superblock numbers aggregated in MFA,
        if None, all superblocks are aggregated.
      endpoint_scale_layer: 0-based (possibly negative) superblock index
        whose time scale is used for the MFA output.
      use_norm: if True, normalization layers are used.
      norm_layer: type of normalization layer.
      norm_before: if True, normalization is applied before the activation.
      upsampling_mode: upsampling method when upsampling feature maps
        for MFA.
    """

    def __init__(
        self,
        in_feats,
        in_conv_channels=128,
        in_kernel_size=3,
        in_stride=1,
        resb_type="basic",
        resb_repeats=[1, 1, 1],
        resb_channels=128,
        resb_kernel_sizes=3,
        resb_strides=2,
        resb_dilations=1,
        resb_groups=1,
        head_channels=0,
        hid_act="relu6",
        head_act=None,
        dropout_rate=0,
        drop_connect_rate=0,
        se_r=16,
        res2net_width_factor=1,
        res2net_scale=4,
        multilayer=False,
        multilayer_concat=False,
        endpoint_channels=None,
        endpoint_layers=None,
        endpoint_scale_layer=-1,
        use_norm=True,
        norm_layer=None,
        norm_before=True,
        upsampling_mode="nearest",
    ):
        super().__init__()

        self.resb_type = resb_type
        self._block, bargs = self._select_block(
            resb_type, se_r, res2net_width_factor, res2net_scale
        )

        self.in_feats = in_feats
        self.in_conv_channels = in_conv_channels
        self.in_kernel_size = in_kernel_size
        self.in_stride = in_stride
        # copy, so the (mutable) default list is never aliased by instances
        self.resb_repeats = list(resb_repeats)
        num_superblocks = len(self.resb_repeats)
        self.resb_channels = self._standarize_resblocks_param(
            resb_channels, num_superblocks, "resb_channels"
        )
        self.resb_kernel_sizes = self._standarize_resblocks_param(
            resb_kernel_sizes, num_superblocks, "resb_kernel_sizes"
        )
        self.resb_strides = self._standarize_resblocks_param(
            resb_strides, num_superblocks, "resb_strides"
        )
        self.resb_dilations = self._standarize_resblocks_param(
            resb_dilations, num_superblocks, "resb_dilations"
        )
        self.resb_groups = resb_groups
        self.head_channels = head_channels
        self.hid_act = hid_act
        self.head_act = head_act
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        self.use_norm = use_norm
        self.norm_before = norm_before
        self.se_r = se_r
        self.res2net_width_factor = res2net_width_factor
        self.res2net_scale = res2net_scale
        self.norm_layer = norm_layer

        norm_groups = None
        if norm_layer == "group-norm":
            norm_groups = min(min(self.resb_channels) // 2, 32)
            norm_groups = max(norm_groups, resb_groups)
        self._norm_layer = NLF.create(norm_layer, norm_groups)

        # stem block
        self.in_block = DC1dEncBlock(
            in_feats,
            in_conv_channels,
            in_kernel_size,
            stride=in_stride,
            activation=hid_act,
            dropout_rate=dropout_rate,
            use_norm=use_norm,
            norm_layer=self._norm_layer,
            norm_before=norm_before,
        )
        self._context = self.in_block.context
        self._downsample_factor = self.in_block.stride

        cur_in_channels = in_conv_channels
        total_blocks = int(sum(self.resb_repeats))
        # the drop-connect rate grows linearly from 0 (first block) to
        # drop_connect_rate (last block); with a single block the
        # denominator is clamped to avoid a division by zero
        drop_denom = max(total_blocks - 1, 1)

        # arguments shared by all the residual blocks
        common_args = dict(
            groups=self.resb_groups,
            activation=hid_act,
            dropout_rate=dropout_rate,
            use_norm=use_norm,
            norm_layer=self._norm_layer,
            norm_before=norm_before,
        )
        common_args.update(bargs)

        # middle blocks
        self.blocks = nn.ModuleList([])
        self.resb_scales = []
        k = 0
        for i in range(num_superblocks):
            repeats_i = self.resb_repeats[i]
            channels_i = self.resb_channels[i]
            stride_i = self.resb_strides[i]
            kernel_size_i = self.resb_kernel_sizes[i]
            dilation_i = self.resb_dilations[i]
            # if there is downsampling the dilation of the first block
            # is set to 1
            dilation_i1 = dilation_i if stride_i == 1 else 1

            blocks_i = nn.ModuleList([])
            block_i1 = self._block(
                cur_in_channels,
                channels_i,
                kernel_size_i,
                stride=stride_i,
                dilation=dilation_i1,
                drop_connect_rate=drop_connect_rate * k / drop_denom,
                **common_args,
            )
            blocks_i.append(block_i1)
            k += 1
            self._context += block_i1.context * self._downsample_factor
            self._downsample_factor *= block_i1.downsample_factor
            self.resb_scales.append(self._downsample_factor)

            # remaining blocks of the superblock keep the time resolution
            for _ in range(repeats_i - 1):
                block_ij = self._block(
                    channels_i,
                    channels_i,
                    kernel_size_i,
                    stride=1,
                    dilation=dilation_i,
                    drop_connect_rate=drop_connect_rate * k / drop_denom,
                    **common_args,
                )
                blocks_i.append(block_ij)
                k += 1
                self._context += block_ij.context * self._downsample_factor

            self.blocks.append(blocks_i)
            cur_in_channels = channels_i

        if multilayer:
            if endpoint_layers is None:
                # if it is None all layers are endpoints
                endpoint_layers = [i + 1 for i in range(num_superblocks)]

            if endpoint_channels is None:
                # if None, the number of endpoint channels matches the one
                # of the endpoint scale layer
                endpoint_channels = self.resb_channels[endpoint_scale_layer]

            # which layers are endpoints (endpoint_layers is 1-based)
            self.is_endpoint = [
                (i + 1) in endpoint_layers for i in range(num_superblocks)
            ]
            # which endpoints have a projection layer ResNet1dEndpoint
            self.has_endpoint_block = [False] * num_superblocks
            # relates endpoint layers to their ResNet1dEndpoint object
            self.endpoint_block_idx = [0] * num_superblocks

            endpoint_scale = self.resb_scales[endpoint_scale_layer]
            endpoint_blocks = nn.ModuleList([])
            cur_endpoint = 0
            in_concat_channels = 0
            for i in range(num_superblocks):
                if not self.is_endpoint[i]:
                    continue

                if multilayer_concat:
                    # when concatenating, channels are kept and a projection
                    # is only needed to change the time scale
                    out_channels = self.resb_channels[i]
                    if self.resb_scales[i] != endpoint_scale:
                        self.has_endpoint_block[i] = True
                    in_concat_channels += out_channels
                else:
                    self.has_endpoint_block[i] = True
                    out_channels = endpoint_channels

                if self.has_endpoint_block[i]:
                    endpoint_i = ResNet1dEndpoint(
                        self.resb_channels[i],
                        out_channels,
                        in_scale=self.resb_scales[i],
                        scale=endpoint_scale,
                        activation=hid_act,
                        upsampling_mode=upsampling_mode,
                        norm_layer=self._norm_layer,
                        norm_before=norm_before,
                    )
                    self.endpoint_block_idx[i] = cur_endpoint
                    endpoint_blocks.append(endpoint_i)
                    cur_endpoint += 1

            self.endpoint_blocks = endpoint_blocks
            if multilayer_concat:
                self.concat_endpoint_block = ResNet1dEndpoint(
                    in_concat_channels,
                    endpoint_channels,
                    in_scale=1,
                    scale=1,
                    activation=hid_act,
                    norm_layer=self._norm_layer,
                    norm_before=norm_before,
                )
        else:
            endpoint_channels = self.resb_channels[-1]

        self.multilayer = multilayer
        self.multilayer_concat = multilayer_concat
        self.endpoint_channels = endpoint_channels
        self.endpoint_layers = endpoint_layers
        self.endpoint_scale_layer = endpoint_scale_layer
        self.upsampling_mode = upsampling_mode

        # head feature block
        if self.head_channels > 0:
            # the head input is the output of the last superblock or of the
            # endpoint aggregation; both have endpoint_channels channels
            self.head_block = DC1dEncBlock(
                self.endpoint_channels,
                head_channels,
                kernel_size=1,
                stride=1,
                activation=head_act,
                use_norm=False,
                norm_before=norm_before,
            )

        self._init_weights(hid_act)

    @staticmethod
    def _select_block(resb_type, se_r, res2net_width_factor, res2net_scale):
        """Selects the residual block class and its extra arguments.

        Args:
          resb_type: residual block type name.
          se_r: squeeze-excitation compression ratio.
          res2net_width_factor: res2net width factor.
          res2net_scale: res2net scale.

        Returns:
          Tuple (block class, dict with the block's extra keyword arguments).

        Raises:
          ValueError: if resb_type is not a known block type.
        """
        if resb_type == "basic":
            return ResNet1dBasicBlock, {}
        if resb_type == "bn":
            return ResNet1dBNBlock, {}
        if resb_type == "sebasic":
            return SEResNet1dBasicBlock, {"se_r": se_r}
        if resb_type == "sebn":
            return SEResNet1dBNBlock, {"se_r": se_r}
        if resb_type in ("res2basic", "seres2basic", "res2bn", "seres2bn"):
            bargs = {"width_factor": res2net_width_factor, "scale": res2net_scale}
            if resb_type in ("seres2basic", "seres2bn"):
                bargs["se_r"] = se_r
            if resb_type in ("res2basic", "seres2basic"):
                return Res2Net1dBasicBlock, bargs
            return Res2Net1dBNBlock, bargs

        raise ValueError("unknown resb_type={}".format(resb_type))

    def _init_weights(self, hid_act):
        """Initializes conv. weights (Kaiming normal) and batch-norm params.

        Args:
          hid_act: hidden activation name or config dict with a "name" key;
            it selects the gain used by the Kaiming initialization. When the
            name is unknown to torch, the gain of "relu" is used.
        """
        act_name = "relu"
        if isinstance(hid_act, str):
            act_name = hid_act
        elif isinstance(hid_act, dict):
            act_name = hid_act.get("name", "relu")

        if act_name == "swish":
            act_name = "relu"

        # torch only knows the gain of a few non-linearities,
        # fall back to relu for the others (e.g., relu6)
        try:
            nn.init.calculate_gain(act_name)
        except ValueError:
            act_name = "relu"

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity=act_name
                )
            elif isinstance(m, nn.BatchNorm1d):
                # non-affine batch-norm layers have no weight/bias
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _standarize_resblocks_param(p, num_blocks, p_name):
        """Expands a superblock parameter into a list with one value per
        superblock.

        Args:
          p: int, or list/tuple with 1 or num_blocks values.
          num_blocks: number of superblocks.
          p_name: parameter name, only used in error messages.

        Returns:
          List of length num_blocks.

        Raises:
          AssertionError: if p is a sequence of the wrong length.
          TypeError: if p is neither an int nor a list/tuple.
        """
        if isinstance(p, int):
            return [p] * num_blocks

        if isinstance(p, (list, tuple)):
            p = list(p)
            if len(p) == 1:
                p = p * num_blocks

            assert len(p) == num_blocks, "len(%s)(%d)!=%d" % (
                p_name,
                len(p),
                num_blocks,
            )
            return p

        raise TypeError("wrong type for param {}={}".format(p_name, p))

    def _compute_out_size(self, in_size):
        """Computes the output length given the input length.

        Applies the input stride and then the strides of the superblocks.
        With multilayer aggregation, only the superblocks up to (and
        including) the endpoint scale layer are considered, since that layer
        fixes the output time scale.

        Args:
          in_size: input length.

        Returns:
          Output length (int).
        """
        out_size = int((in_size - 1) // self.in_stride + 1)
        if self.multilayer:
            last = self.endpoint_scale_layer
            if last < 0:
                last += len(self.resb_strides)
            strides = self.resb_strides[: last + 1]
        else:
            strides = self.resb_strides

        for stride in strides:
            out_size = int((out_size - 1) // stride + 1)

        return out_size

    def in_context(self):
        """Returns a 2-tuple with the accumulated context of the network,
        the same value in both positions (presumably left and right context).
        """
        return (self._context, self._context)

    def in_shape(self):
        """Returns the expected input shape, with None in the free dims.

        The second dim is the number of input features; the first and last
        ones are presumably batch and time.
        """
        return (None, self.in_feats, None)

    def out_shape(self, in_shape=None):
        """Computes the output shape given the input shape.

        Args:
          in_shape: 3-dim input shape or None; its last dim may be None.

        Returns:
          Tuple (in_shape[0], out_channels, out_length), with None in the
          dims that cannot be determined.
        """
        out_channels = (
            self.head_channels if self.head_channels > 0 else self.endpoint_channels
        )
        if in_shape is None:
            return (None, out_channels, None)

        assert len(in_shape) == 3
        if in_shape[2] is None:
            T = None
        else:
            T = self._compute_out_size(in_shape[2])

        return (in_shape[0], out_channels, T)

    @staticmethod
    def _match_lens(endpoints):
        """Center-crops the endpoints along the last dim, so that all of
        them have the length of the shortest one.

        Args:
          endpoints: list of 3d tensors, it is modified in place.

        Returns:
          The same list with the cropped tensors.
        """
        lens = [e.shape[-1] for e in endpoints]
        min_len = min(lens)
        for i, len_i in enumerate(lens):
            if len_i > min_len:
                t_start = (len_i - min_len) // 2
                t_end = t_start + min_len
                endpoints[i] = endpoints[i][:, :, t_start:t_end]

        return endpoints

    def _apply_endpoint(self, i, x):
        """Applies the projection of endpoint i to x, if that endpoint has
        one; otherwise x is returned unchanged."""
        if self.has_endpoint_block[i]:
            idx = self.endpoint_block_idx[i]
            x = self.endpoint_blocks[idx](x)

        return x

    def _aggregate_endpoints(self, endpoints):
        """Combines the endpoints into a single tensor.

        The endpoints are cropped to a common length and then concatenated
        along dim 1 and projected (multilayer_concat) or averaged.

        Raises:
          RuntimeError: if the endpoints cannot be concatenated.
        """
        endpoints = self._match_lens(endpoints)
        if self.multilayer_concat:
            try:
                x = torch.cat(endpoints, dim=1)
            except RuntimeError as err:
                shapes = [tuple(e.shape) for e in endpoints]
                raise RuntimeError(
                    "cannot concatenate endpoints with shapes {}".format(shapes)
                ) from err

            return self.concat_endpoint_block(x)

        return torch.mean(torch.stack(endpoints), 0)

    def forward(self, x):
        """Forward pass of the encoder.

        Args:
          x: input tensor with in_feats in dim 1.

        Returns:
          Output tensor of the last superblock, or of the endpoint
          aggregation when multilayer is True, after the optional head block.
        """
        x = self.in_block(x)
        endpoints = []
        for i, superblock in enumerate(self.blocks):
            for block in superblock:
                x = block(x)

            if self.multilayer and self.is_endpoint[i]:
                endpoints.append(self._apply_endpoint(i, x))

        if self.multilayer:
            x = self._aggregate_endpoints(endpoints)

        if self.head_channels > 0:
            x = self.head_block(x)

        return x

    def forward_hid_feats(self, x, layers=None, return_output=False):
        """Forward pass returning hidden features and/or the output.

        Args:
          x: input tensor with in_feats in dim 1.
          layers: list of layers whose outputs are wanted; 0 is the input
            block and i is the i-th superblock (1-based).
          return_output: if True, the network output is computed and
            returned; otherwise the hidden features are returned and the
            forward pass stops at the last requested layer.

        Returns:
          The network output if return_output is True, otherwise the list
          with the outputs of the requested layers.
        """
        assert layers is not None or return_output
        if layers is None:
            layers = []

        if return_output:
            last_layer = len(self.blocks) + 1
        else:
            last_layer = max(layers, default=0)

        h = []
        x = self.in_block(x)
        if 0 in layers:
            h.append(x)

        endpoints = []
        for i, superblock in enumerate(self.blocks):
            if i >= last_layer:
                # all the requested layers have been computed
                break

            for block in superblock:
                x = block(x)

            if i + 1 in layers:
                h.append(x)

            if return_output and self.multilayer and self.is_endpoint[i]:
                endpoints.append(self._apply_endpoint(i, x))

        if not return_output:
            return h

        if self.multilayer:
            x = self._aggregate_endpoints(endpoints)

        if self.head_channels > 0:
            x = self.head_block(x)

        return x

    def get_config(self):
        """Returns a dict with the constructor arguments of the network,
        merged with the config of the base class."""
        config = {
            "in_feats": self.in_feats,
            "in_conv_channels": self.in_conv_channels,
            "in_kernel_size": self.in_kernel_size,
            "in_stride": self.in_stride,
            "resb_type": self.resb_type,
            "resb_repeats": self.resb_repeats,
            "resb_channels": self.resb_channels,
            "resb_kernel_sizes": self.resb_kernel_sizes,
            "resb_strides": self.resb_strides,
            "resb_dilations": self.resb_dilations,
            "resb_groups": self.resb_groups,
            "head_channels": self.head_channels,
            "dropout_rate": self.dropout_rate,
            "drop_connect_rate": self.drop_connect_rate,
            "hid_act": self.hid_act,
            "head_act": self.head_act,
            "se_r": self.se_r,
            "res2net_width_factor": self.res2net_width_factor,
            "res2net_scale": self.res2net_scale,
            "use_norm": self.use_norm,
            "norm_layer": self.norm_layer,
            "norm_before": self.norm_before,
            "multilayer": self.multilayer,
            "multilayer_concat": self.multilayer_concat,
            "endpoint_channels": self.endpoint_channels,
            "endpoint_layers": self.endpoint_layers,
            "endpoint_scale_layer": self.endpoint_scale_layer,
            "upsampling_mode": self.upsampling_mode,
        }

        base_config = super().get_config()
        # on key collisions the values of this class take precedence
        return {**base_config, **config}

    @staticmethod
    def filter_args(**kwargs):
        """Keeps only the arguments accepted by the constructor.

        The command line flags "wo_norm" and "norm_after" are converted into
        the constructor arguments "use_norm" and "norm_before".

        Args:
          kwargs: dict with arguments, e.g., from the argument parser.

        Returns:
          Dict with the valid constructor arguments found in kwargs.
        """
        if "wo_norm" in kwargs:
            kwargs["use_norm"] = not kwargs.pop("wo_norm")

        if "norm_after" in kwargs:
            kwargs["norm_before"] = not kwargs.pop("norm_after")

        valid_args = (
            "in_feats",
            "in_conv_channels",
            "in_kernel_size",
            "in_stride",
            "resb_type",
            "resb_repeats",
            "resb_channels",
            "resb_kernel_sizes",
            "resb_strides",
            "resb_dilations",
            "resb_groups",
            "head_channels",
            "se_r",
            "res2net_width_factor",
            "res2net_scale",
            "hid_act",
            "head_act",
            "dropout_rate",
            "drop_connect_rate",
            "use_norm",
            "norm_layer",
            "norm_before",
            "multilayer",
            "multilayer_concat",
            "endpoint_channels",
            "endpoint_layers",
            "endpoint_scale_layer",
            "upsampling_mode",
        )

        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None, skip=set(["in_feats"])):
        """Adds the encoder options to an argument parser.

        Args:
          parser: argument parser; when prefix is None the options are added
            directly to it.
          prefix: if not None, the options are added to a new parser that is
            attached to the given one under the "--<prefix>" option.
          skip: names of the options that are not added; only "in_feats" and
            "head_channels" can be skipped.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        def add_if_absent(*args, **kwargs):
            # options shared with other components may already be defined in
            # the parser; in that case the existing definition is kept
            try:
                parser.add_argument(*args, **kwargs)
            except Exception:
                pass

        if "in_feats" not in skip:
            parser.add_argument(
                "--in-feats", type=int, required=True, help=("input feature dimension")
            )

        parser.add_argument(
            "--in-conv-channels",
            default=128,
            type=int,
            help=("number of output channels in input convolution"),
        )

        parser.add_argument(
            "--in-kernel-size",
            default=3,
            type=int,
            help=("kernel size of input convolution"),
        )

        parser.add_argument(
            "--in-stride", default=1, type=int, help=("stride of input convolution")
        )

        parser.add_argument(
            "--resb-type",
            default="basic",
            choices=[
                "basic",
                "bn",
                "sebasic",
                "sebn",
                "res2basic",
                "res2bn",
                "seres2basic",
                "seres2bn",
            ],
            help=("residual blocks type"),
        )

        parser.add_argument(
            "--resb-repeats",
            default=[1, 1, 1],
            type=int,
            nargs="+",
            help=("resb-blocks repeats in each encoder stage"),
        )

        parser.add_argument(
            "--resb-channels",
            default=[128, 64, 32],
            type=int,
            nargs="+",
            help=("resb-blocks channels for each stage"),
        )

        parser.add_argument(
            "--resb-kernel-sizes",
            default=[3],
            nargs="+",
            type=int,
            help=("resb-blocks kernels for each encoder stage"),
        )

        parser.add_argument(
            "--resb-strides",
            default=[2],
            nargs="+",
            type=int,
            help=("resb-blocks strides for each encoder stage"),
        )

        parser.add_argument(
            "--resb-dilations",
            default=[1],
            nargs="+",
            type=int,
            help=("resb-blocks dilations for each encoder stage"),
        )

        parser.add_argument(
            "--resb-groups",
            default=1,
            type=int,
            help=("resb-blocks groups in convolutions"),
        )

        if "head_channels" not in skip:
            parser.add_argument(
                "--head-channels",
                default=0,
                type=int,
                help=("channels in the last conv block of encoder"),
            )

        add_if_absent("--hid-act", default="relu6", help="hidden activation")

        parser.add_argument(
            "--head-act", default=None, help="activation in encoder head"
        )

        add_if_absent(
            "--dropout-rate", default=0, type=float, help="dropout probability"
        )

        add_if_absent(
            "--drop-connect-rate",
            default=0,
            type=float,
            help="layer drop probability",
        )

        add_if_absent(
            "--norm-layer",
            default=None,
            choices=[
                "batch-norm",
                "group-norm",
                "instance-norm",
                "instance-norm-affine",
                "layer-norm",
            ],
            help="type of normalization layer",
        )

        parser.add_argument(
            "--wo-norm",
            default=False,
            action="store_true",
            help="without batch normalization",
        )

        parser.add_argument(
            "--norm-after",
            default=False,
            action="store_true",
            help="batch normalizaton after activation",
        )

        parser.add_argument(
            "--se-r",
            default=16,
            type=int,
            help=("squeeze-excitation compression ratio"),
        )

        parser.add_argument(
            "--res2net-width-factor",
            default=1,
            type=float,
            help=(
                "scaling factor for channels in middle layer "
                "of res2net bottleneck blocks"
            ),
        )

        # NOTE(review): this default (1) differs from the constructor
        # default res2net_scale=4 -- confirm which one is intended
        parser.add_argument(
            "--res2net-scale",
            default=1,
            type=int,
            help=("res2net scaling parameter "),
        )

        parser.add_argument(
            "--multilayer",
            default=False,
            action="store_true",
            help="use multilayer feature aggregation (mfa)",
        )

        parser.add_argument(
            "--multilayer-concat",
            default=False,
            action="store_true",
            help="use concatenation for mfa",
        )

        parser.add_argument(
            "--endpoint-channels",
            default=None,
            type=int,
            help=("num. endpoint channels when using mfa"),
        )

        parser.add_argument(
            "--endpoint-layers",
            default=None,
            nargs="+",
            type=int,
            help=(
                "layers to aggreagate in mfa, "
                "if None, all residual blocks are aggregated"
            ),
        )

        parser.add_argument(
            "--endpoint-scale-layer",
            default=-1,
            type=int,
            help=("layer number which indicates the time scale in mfa"),
        )

        parser.add_argument(
            "--upsampling-mode",
            choices=["nearest", "bilinear", "subpixel"],
            default="nearest",
            help=("upsampling method when upsampling feature maps for mfa"),
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    # alias kept for backward compatibility
    add_argparse_args = add_class_args