Source code for hyperion.torch.models.xvectors.resnet1d_xvector

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import logging
from jsonargparse import ArgumentParser, ActionParser

import torch
import torch.nn as nn

from .xvector import XVector
from ...narchs import ResNet1dEncoder as Encoder


class ResNet1dXVector(XVector):
    """x-vector model whose encoder network is a ``ResNet1dEncoder``.

    The class only customizes construction, configuration export, checkpoint
    loading and command-line argument handling; everything else is inherited
    from ``XVector``.
    """

    def __init__(
        self,
        resnet_enc,
        num_classes,
        pool_net="mean+stddev",
        embed_dim=256,
        num_embed_layers=1,
        # NOTE: mutable default kept only for interface compatibility; it is
        # copied before being forwarded, so it is never shared between models.
        hid_act={"name": "relu", "inplace": True},
        loss_type="arc-softmax",
        s=64,
        margin=0.3,
        margin_warmup_epochs=0,
        num_subcenters=2,
        dropout_rate=0,
        norm_layer=None,
        head_norm_layer=None,
        use_norm=True,
        norm_before=True,
        in_norm=False,
        embed_layer=0,
        proj_feats=None,
    ):
        """Build the encoder (if needed) and initialize the parent ``XVector``.

        Args:
          resnet_enc: either an already constructed encoder object or a dict
            of keyword arguments used to build a ``ResNet1dEncoder``. The dict
            must contain the key ``"resb_type"``, which is logged.
          num_classes: forwarded to ``XVector.__init__``.
          pool_net, embed_dim, num_embed_layers, hid_act, loss_type, s,
          margin, margin_warmup_epochs, num_subcenters, dropout_rate,
          norm_layer, head_norm_layer, use_norm, norm_before, embed_layer,
          proj_feats: forwarded unchanged to ``XVector.__init__`` as keyword
            arguments.
          in_norm: accepted for interface compatibility but not used by this
            constructor (it is not forwarded to ``XVector.__init__``).
        """
        if isinstance(resnet_enc, dict):
            logging.info(
                "making %s resnet1d encoder network", resnet_enc["resb_type"]
            )
            resnet_enc = Encoder(**resnet_enc)

        # Shallow-copy dict activations so that neither the shared default
        # above nor a caller-owned dict is aliased by the parent class.
        if isinstance(hid_act, dict):
            hid_act = dict(hid_act)

        super().__init__(
            resnet_enc,
            num_classes,
            pool_net=pool_net,
            embed_dim=embed_dim,
            num_embed_layers=num_embed_layers,
            hid_act=hid_act,
            loss_type=loss_type,
            s=s,
            margin=margin,
            margin_warmup_epochs=margin_warmup_epochs,
            num_subcenters=num_subcenters,
            norm_layer=norm_layer,
            head_norm_layer=head_norm_layer,
            use_norm=use_norm,
            norm_before=norm_before,
            dropout_rate=dropout_rate,
            embed_layer=embed_layer,
            proj_feats=proj_feats,
        )

    def get_config(self):
        """Return the configuration dict of the model.

        The parent configuration is taken without its ``"encoder_cfg"`` and
        ``"in_feats"`` entries, and the encoder configuration (without its
        ``"class_name"`` entry) is stored under ``"resnet_enc"``.

        Raises:
          KeyError: if any of the removed entries is missing.
        """
        base_config = super().get_config()
        del base_config["encoder_cfg"]
        del base_config["in_feats"]

        encoder_cfg = self.encoder_net.get_config()
        del encoder_cfg["class_name"]

        # "resnet_enc" goes first; parent entries follow and take precedence
        # on a key clash, exactly as dict.update() would.
        return {"resnet_enc": encoder_cfg, **base_config}

    @classmethod
    def load(cls, file_path=None, cfg=None, state_dict=None):
        """Instantiate a model from a configuration and optional weights.

        Args:
          file_path, cfg, state_dict: passed to ``cls._load_cfg_state_dict``,
            which returns the configuration and state dict actually used.

        Returns:
          A new instance of ``cls``; ``state_dict`` is loaded into it when it
          is not ``None``.
        """
        cfg, state_dict = cls._load_cfg_state_dict(file_path, cfg, state_dict)

        # "in_feats" is not a constructor argument of this class, so it must
        # not reach cls(**cfg); it is fine for the key to be absent.
        cfg.pop("in_feats", None)

        logging.debug("loading %s with config: %s", cls.__name__, cfg)
        model = cls(**cfg)
        if state_dict is not None:
            model.load_state_dict(state_dict)

        return model

    @staticmethod
    def filter_args(**kwargs):
        """Select the constructor arguments of this class from ``kwargs``.

        Args:
          **kwargs: candidate arguments; must include a ``"resnet_enc"``
            mapping with the encoder options.

        Returns:
          The result of ``XVector.filter_args(**kwargs)`` with an added
          ``"resnet_enc"`` entry holding
          ``Encoder.filter_args(**kwargs["resnet_enc"])``.

        Raises:
          KeyError: if ``kwargs`` has no ``"resnet_enc"`` entry.
        """
        base_args = XVector.filter_args(**kwargs)
        base_args["resnet_enc"] = Encoder.filter_args(**kwargs["resnet_enc"])
        return base_args

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Register the model options on an argument parser.

        Args:
          parser: parser that receives the options.
          prefix: when not ``None``, the options are registered on a fresh
            ``ArgumentParser`` that is attached to ``parser`` as the nested
            option ``--<prefix>`` through ``ActionParser``.
        """
        outer_parser = None
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        XVector.add_class_args(parser, skip={"in_feats"})
        Encoder.add_class_args(parser, prefix="resnet_enc", skip={"head_channels"})

        if outer_parser is not None:
            outer_parser.add_argument(
                "--" + prefix, action=ActionParser(parser=parser)
            )

    # Alias kept for callers that use the alternative method name.
    add_argparse_args = add_class_args