# Source code for hyperion.torch.narchs.dc1d_encoder

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
from jsonargparse import ArgumentParser, ActionParser
import math

import torch
import torch.nn as nn

from ..layers import ActivationFactory as AF
from ..layers import NormLayer1dFactory as NLF
from ..layer_blocks.dc1d_blocks import DC1dEncBlock
from .net_arch import NetArch


class DC1dEncoder(NetArch):
    """1d convolutional encoder assembled from ``DC1dEncBlock`` modules.

    The network is a chain of an input (stem) block, a sequence of encoder
    stages ("superblocks"), and an optional head block with kernel size 1.
    Each stage ``i`` contains ``conv_repeats[i]`` blocks; only the first block
    of a stage applies the stage stride, the remaining ones use stride 1.

    Args:
        in_feats: input feature dimension.
        in_conv_channels: number of output channels in the input convolution.
        in_kernel_size: kernel size of the input convolution.
        in_stride: stride of the input convolution.
        conv_repeats: list with the number of blocks in each encoder stage;
            its length defines the number of stages.
        conv_channels: channels for each stage (int, or list with one entry
            or one entry per stage).
        conv_kernel_sizes: kernel sizes for each stage (int or list).
        conv_strides: strides for each stage (int or list).
        conv_dilations: dilations for each stage (int or list).
        head_channels: channels of the head block; no head block is created
            when it is 0.
        hid_act: hidden activation, a name or a dict with a ``"name"`` key.
        head_act: activation in the encoder head.
        dropout_rate: dropout probability.
        use_norm: whether the stem/stage blocks use a normalization layer.
        norm_layer: type of normalization layer, forwarded to
            ``NormLayer1dFactory.create``.
        norm_before: forwarded to every block as ``norm_before``
            (presumably normalization before the activation -- confirm in
            ``DC1dEncBlock``).
    """

    def __init__(
        self,
        in_feats,
        in_conv_channels=128,
        in_kernel_size=3,
        in_stride=1,
        conv_repeats=[1, 1, 1],
        conv_channels=[128, 64, 32],
        conv_kernel_sizes=3,
        conv_strides=2,
        conv_dilations=1,
        head_channels=0,
        hid_act="relu6",
        head_act=None,
        dropout_rate=0,
        use_norm=True,
        norm_layer=None,
        norm_before=True,
    ):
        super().__init__()

        self.in_feats = in_feats
        self.in_conv_channels = in_conv_channels
        self.in_kernel_size = in_kernel_size
        self.in_stride = in_stride

        num_superblocks = len(conv_repeats)
        # Copy so that the (mutable) default list is never shared between
        # instances or leaked through get_config().
        self.conv_repeats = list(conv_repeats)
        self.conv_channels = self._standarize_convblocks_param(
            conv_channels, num_superblocks, "conv_channels"
        )
        self.conv_kernel_sizes = self._standarize_convblocks_param(
            conv_kernel_sizes, num_superblocks, "conv_kernel_sizes"
        )
        self.conv_strides = self._standarize_convblocks_param(
            conv_strides, num_superblocks, "conv_strides"
        )
        self.conv_dilations = self._standarize_convblocks_param(
            conv_dilations, num_superblocks, "conv_dilations"
        )

        self.head_channels = head_channels
        self.hid_act = hid_act
        self.head_act = head_act
        self.dropout_rate = dropout_rate
        self.use_norm = use_norm
        self.norm_before = norm_before
        self.norm_layer = norm_layer

        norm_groups = None
        if norm_layer == "group-norm":
            # Builtin min: numpy is not imported in this module, and the
            # channel counts are plain Python ints.
            norm_groups = min(min(self.conv_channels) // 2, 32)

        self._norm_layer = NLF.create(norm_layer, norm_groups)

        # stem block
        self.in_block = DC1dEncBlock(
            in_feats,
            in_conv_channels,
            in_kernel_size,
            stride=in_stride,
            activation=hid_act,
            dropout_rate=dropout_rate,
            use_norm=use_norm,
            norm_layer=self._norm_layer,
            norm_before=norm_before,
        )
        # Running totals: each block's context is scaled by the downsampling
        # accumulated before it.
        self._context = self.in_block.context
        self._downsample_factor = self.in_block.stride

        cur_in_channels = in_conv_channels

        # middle blocks
        self.blocks = nn.ModuleList([])
        for i in range(num_superblocks):
            repeats_i = self.conv_repeats[i]
            channels_i = self.conv_channels[i]
            stride_i = self.conv_strides[i]
            kernel_size_i = self.conv_kernel_sizes[i]
            dilation_i = self.conv_dilations[i]

            # First block of the stage: changes the channel count and applies
            # the stage stride. It is built with dilation=1; the stage
            # dilation is only used by the repeated blocks below.
            block_i = DC1dEncBlock(
                cur_in_channels,
                channels_i,
                kernel_size_i,
                stride=stride_i,
                dilation=1,
                activation=hid_act,
                dropout_rate=dropout_rate,
                use_norm=use_norm,
                norm_layer=self._norm_layer,
                norm_before=norm_before,
            )
            self.blocks.append(block_i)
            self._context += block_i.context * self._downsample_factor
            self._downsample_factor *= block_i.stride

            # Remaining blocks of the stage: same channels, stride 1.
            for _ in range(repeats_i - 1):
                block_i = DC1dEncBlock(
                    channels_i,
                    channels_i,
                    kernel_size_i,
                    stride=1,
                    dilation=dilation_i,
                    activation=hid_act,
                    dropout_rate=dropout_rate,
                    use_norm=use_norm,
                    norm_layer=self._norm_layer,
                    norm_before=norm_before,
                )
                self.blocks.append(block_i)
                self._context += block_i.context * self._downsample_factor

            cur_in_channels = channels_i

        # head feature block
        if self.head_channels > 0:
            self.head_block = DC1dEncBlock(
                cur_in_channels,
                head_channels,
                kernel_size=1,
                stride=1,
                activation=head_act,
                use_norm=False,
                norm_before=norm_before,
            )

        self._init_weights(hid_act)

    def _init_weights(self, hid_act):
        """Initializes Conv1d weights (Kaiming normal, fan_out) and
        BatchNorm1d affine parameters (weight=1, bias=0).

        Args:
            hid_act: hidden activation as a name or a dict with a ``"name"``
                key; any other type falls back to ``"relu"``.
        """
        # Resolve the nonlinearity name once; it does not depend on the module.
        act_name = "relu"
        if isinstance(hid_act, str):
            act_name = hid_act
        elif isinstance(hid_act, dict):
            act_name = hid_act["name"]

        if act_name == "swish":
            act_name = "relu"

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                try:
                    nn.init.kaiming_normal_(
                        m.weight, mode="fan_out", nonlinearity=act_name
                    )
                except ValueError:
                    # The gain is not defined for this nonlinearity
                    # (e.g. "relu6"); use the relu gain instead.
                    nn.init.kaiming_normal_(
                        m.weight, mode="fan_out", nonlinearity="relu"
                    )
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _standarize_convblocks_param(p, num_blocks, p_name):
        """Expands a per-stage parameter into a list with one value per stage.

        Args:
            p: an int (replicated for every stage) or a list/tuple with either
                one element (replicated) or ``num_blocks`` elements.
            num_blocks: number of encoder stages.
            p_name: parameter name, used only in error messages.

        Returns:
            A new list of length ``num_blocks``.

        Raises:
            TypeError: if ``p`` is not an int, list or tuple.
            AssertionError: if the sequence length is neither 1 nor
                ``num_blocks``.
        """
        if isinstance(p, int):
            p = [p] * num_blocks
        elif isinstance(p, (list, tuple)):
            # Always build a new list so the caller's object is not aliased.
            p = list(p)
            if len(p) == 1:
                p = p * num_blocks

            assert len(p) == num_blocks, "len(%s)(%d)!=%d" % (
                p_name,
                len(p),
                num_blocks,
            )
        else:
            raise TypeError("wrong type for param {}={}".format(p_name, p))

        return p

    def _compute_out_size(self, in_size):
        """Returns the output length for an input of length ``in_size`` after
        applying the input stride and every stage stride (ceil division)."""
        out_size = int((in_size - 1) // self.in_stride + 1)
        for stride in self.conv_strides:
            out_size = int((out_size - 1) // stride + 1)

        return out_size

    def in_context(self):
        """Returns the accumulated context of the stem and stage blocks as a
        2-tuple with the same value in both positions (presumably left and
        right context -- confirm against ``NetArch``)."""
        return (self._context, self._context)

    def in_shape(self):
        """Returns the expected input shape ``(None, in_feats, None)``;
        ``None`` marks dimensions that are not fixed."""
        return (None, self.in_feats, None)

    def out_shape(self, in_shape=None):
        """Returns the output shape of the encoder.

        Args:
            in_shape: optional 3-element input shape. When ``None`` only the
                channel dimension is reported.

        Returns:
            Tuple ``(in_shape[0], out_channels, T)`` where ``out_channels`` is
            ``head_channels`` if a head block exists, else the channels of the
            last stage, and ``T`` is the length produced by the strides
            (``None`` when the input length is ``None``).
        """
        out_channels = (
            self.head_channels if self.head_channels > 0 else self.conv_channels[-1]
        )
        if in_shape is None:
            return (None, out_channels, None)

        assert len(in_shape) == 3
        if in_shape[2] is None:
            T = None
        else:
            T = self._compute_out_size(in_shape[2])

        return (in_shape[0], out_channels, T)

    def forward(self, x):
        """Runs ``x`` through the stem block, every stage block in order and,
        if present, the head block; returns the result."""
        x = self.in_block(x)
        for block in self.blocks:
            x = block(x)

        if self.head_channels > 0:
            x = self.head_block(x)

        return x

    def get_config(self):
        """Returns the constructor arguments of this encoder merged on top of
        the base-class configuration."""
        config = {
            "in_feats": self.in_feats,
            "in_conv_channels": self.in_conv_channels,
            "in_kernel_size": self.in_kernel_size,
            "in_stride": self.in_stride,
            "conv_repeats": self.conv_repeats,
            "conv_channels": self.conv_channels,
            "conv_kernel_sizes": self.conv_kernel_sizes,
            "conv_strides": self.conv_strides,
            "conv_dilations": self.conv_dilations,
            "head_channels": self.head_channels,
            "dropout_rate": self.dropout_rate,
            "hid_act": self.hid_act,
            "head_act": self.head_act,
            "use_norm": self.use_norm,
            "norm_layer": self.norm_layer,
            "norm_before": self.norm_before,
        }

        base_config = super().get_config()
        # Entries of this class override same-named base entries.
        return {**base_config, **config}

    @staticmethod
    def filter_args(**kwargs):
        """Keeps only the keyword arguments accepted by the constructor.

        The command-line flags ``wo_norm`` and ``norm_after`` are first
        translated into their negated constructor counterparts ``use_norm``
        and ``norm_before``.

        Returns:
            Dict with the subset of ``kwargs`` that are constructor arguments.
        """
        if "wo_norm" in kwargs:
            kwargs["use_norm"] = not kwargs["wo_norm"]
            del kwargs["wo_norm"]

        if "norm_after" in kwargs:
            kwargs["norm_before"] = not kwargs["norm_after"]
            del kwargs["norm_after"]

        valid_args = (
            "in_feats",
            "in_conv_channels",
            "in_kernel_size",
            "in_stride",
            "conv_repeats",
            "conv_channels",
            "conv_kernel_sizes",
            "conv_strides",
            "conv_dilations",
            "head_channels",
            "hid_act",
            "head_act",
            "dropout_rate",
            "use_norm",
            "norm_layer",
            "norm_before",
        )

        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None, head_channels=False, in_feats=False):
        """Adds the encoder options to an argument parser.

        Args:
            parser: parser to extend.
            prefix: when not ``None`` the options are added to a new
                ``ArgumentParser`` that is attached to ``parser`` under
                ``--<prefix>``.
            head_channels: if True, adds the ``--head-channels`` option.
            in_feats: if True, adds the required ``--in-feats`` option.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        if in_feats:
            parser.add_argument(
                "--in-feats", type=int, required=True, help=("input feature dimension")
            )

        parser.add_argument(
            "--in-conv-channels",
            default=128,
            type=int,
            help=("number of output channels in input convolution"),
        )

        parser.add_argument(
            "--in-kernel-size",
            default=3,
            type=int,
            help=("kernel size of input convolution"),
        )

        parser.add_argument(
            "--in-stride", default=1, type=int, help=("stride of input convolution")
        )

        parser.add_argument(
            "--conv-repeats",
            default=[1, 1, 1],
            type=int,
            nargs="+",
            help=("conv-blocks repeats in each encoder stage"),
        )

        parser.add_argument(
            "--conv-channels",
            default=[128, 64, 32],
            type=int,
            nargs="+",
            help=("conv-blocks channels for each stage"),
        )

        parser.add_argument(
            "--conv-kernel-sizes",
            default=[3],
            nargs="+",
            type=int,
            help=("conv-blocks kernels for each encoder stage"),
        )

        parser.add_argument(
            "--conv-strides",
            default=[2],
            nargs="+",
            type=int,
            help=("conv-blocks strides for each encoder stage"),
        )

        parser.add_argument(
            "--conv-dilations",
            default=[1],
            nargs="+",
            type=int,
            help=("conv-blocks dilations for each encoder stage"),
        )

        if head_channels:
            parser.add_argument(
                "--head-channels",
                default=16,
                type=int,
                help=("channels in the last conv block of encoder"),
            )

        # The next options are added on a best-effort basis: a failure
        # (presumably because the option already exists in the parser) is
        # ignored. Only Exception is caught so that KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            parser.add_argument("--hid-act", default="relu6", help="hidden activation")
        except Exception:
            pass

        parser.add_argument(
            "--head-act", default=None, help="activation in encoder head"
        )

        try:
            parser.add_argument(
                "--dropout-rate", default=0, type=float, help="dropout probability"
            )
        except Exception:
            pass

        try:
            parser.add_argument(
                "--norm-layer",
                default=None,
                choices=[
                    "batch-norm",
                    "group-norm",
                    "instance-norm",
                    "instance-norm-affine",
                    "layer-norm",
                ],
                help="type of normalization layer",
            )
        except Exception:
            pass

        parser.add_argument(
            "--wo-norm",
            default=False,
            action="store_true",
            help="without batch normalization",
        )

        parser.add_argument(
            "--norm-after",
            default=False,
            action="store_true",
            help="batch normalizaton after activation",
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
            # help='DC1d encoder options')

    # Alias that exposes add_class_args under a second name.
    add_argparse_args = add_class_args