Source code for hyperion.torch.models.xvectors.tdnn_xvector

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import logging
from jsonargparse import ArgumentParser, ActionParser

import torch
import torch.nn as nn

from .xvector import XVector
from ...narchs import TDNNFactory as TF


class TDNNXVector(XVector):
    """x-vector model whose encoder network is built by ``TDNNFactory``.

    The encoder is created with ``TF.create`` from the TDNN-specific
    arguments; pooling, embedding layers and the classification head are
    configured through the ``XVector`` base class.
    """

    def __init__(
        self,
        tdnn_type,
        num_enc_blocks,
        in_feats,
        num_classes,
        enc_hid_units,
        enc_expand_units=None,
        kernel_size=3,
        dilation=1,
        dilation_factor=1,
        pool_net="mean+stddev",
        embed_dim=256,
        num_embed_layers=1,
        hid_act={"name": "relu6", "inplace": True},
        loss_type="arc-softmax",
        s=64,
        margin=0.3,
        margin_warmup_epochs=0,
        num_subcenters=2,
        dropout_rate=0,
        norm_layer=None,
        head_norm_layer=None,
        use_norm=True,
        norm_before=False,
        in_norm=False,
        embed_layer=0,
        proj_feats=None,
    ):
        """Build the TDNN encoder and initialize the x-vector head.

        Args:
          tdnn_type: encoder architecture name passed to ``TF.create``.
          num_enc_blocks: number of encoder blocks (passed to ``TF.create``).
          in_feats: input feature dimension (passed to ``TF.create``).
          num_classes: number of output classes (passed to ``XVector``).
          enc_hid_units: encoder hidden units (passed to ``TF.create``).
          enc_expand_units: encoder expansion units (passed to ``TF.create``).
          kernel_size, dilation, dilation_factor, in_norm: encoder options
            forwarded to ``TF.create``.
          hid_act, dropout_rate, norm_layer, use_norm, norm_before: options
            forwarded to both the encoder and the ``XVector`` base class.
          pool_net, embed_dim, num_embed_layers, loss_type, s, margin,
          margin_warmup_epochs, num_subcenters, head_norm_layer, embed_layer,
          proj_feats: head options forwarded to ``XVector.__init__``.
        """
        # The default ``hid_act`` dict is a single object shared by every
        # call. Hand a copy to the sub-networks so nothing downstream can
        # mutate the default for later instances.
        if isinstance(hid_act, dict):
            hid_act = dict(hid_act)

        logging.info("making %s encoder network", tdnn_type)
        encoder_net = TF.create(
            tdnn_type,
            num_enc_blocks,
            in_feats,
            enc_hid_units,
            enc_expand_units,
            kernel_size=kernel_size,
            dilation=dilation,
            dilation_factor=dilation_factor,
            hid_act=hid_act,
            dropout_rate=dropout_rate,
            norm_layer=norm_layer,
            use_norm=use_norm,
            norm_before=norm_before,
            in_norm=in_norm,
        )

        # in_feats is consumed by the encoder above; the base class is
        # always given in_feats=None.
        super().__init__(
            encoder_net,
            num_classes,
            pool_net=pool_net,
            embed_dim=embed_dim,
            num_embed_layers=num_embed_layers,
            hid_act=hid_act,
            loss_type=loss_type,
            s=s,
            margin=margin,
            margin_warmup_epochs=margin_warmup_epochs,
            num_subcenters=num_subcenters,
            norm_layer=norm_layer,
            head_norm_layer=head_norm_layer,
            use_norm=use_norm,
            norm_before=norm_before,
            dropout_rate=dropout_rate,
            embed_layer=embed_layer,
            in_feats=None,
            proj_feats=proj_feats,
        )

        self.tdnn_type = tdnn_type

    @property
    def num_enc_blocks(self):
        """Number of encoder blocks, read from ``encoder_net.num_blocks``."""
        return self.encoder_net.num_blocks

    @property
    def enc_hid_units(self):
        """Encoder hidden units, read from ``encoder_net.hid_units``."""
        return self.encoder_net.hid_units

    @property
    def enc_expand_units(self):
        """Encoder ``expand_units``, or None if the encoder lacks it."""
        # Not every encoder exposes expand_units. Only a missing attribute
        # maps to None; any other error now propagates instead of being
        # swallowed by a bare except.
        return getattr(self.encoder_net, "expand_units", None)

    @property
    def kernel_size(self):
        """Encoder kernel size, read from ``encoder_net.kernel_size``."""
        return self.encoder_net.kernel_size

    @property
    def dilation(self):
        """Encoder dilation, read from ``encoder_net.dilation``."""
        return self.encoder_net.dilation

    @property
    def dilation_factor(self):
        """Encoder dilation factor, read from ``encoder_net.dilation_factor``."""
        return self.encoder_net.dilation_factor

    @property
    def in_norm(self):
        """Encoder ``in_norm`` flag, read from ``encoder_net.in_norm``."""
        return self.encoder_net.in_norm

    def get_config(self):
        """Return a dict of constructor kwargs that can rebuild this model.

        The base-class config is merged in, but its ``encoder_cfg`` entry is
        dropped because the encoder is described by the explicit TDNN keys.
        Keys present in the base config override the TDNN keys below.
        """
        base_config = super().get_config()
        del base_config["encoder_cfg"]

        # NOTE(review): __init__ passes in_feats=None to the base class.
        # Confirm that self.in_feats (and any "in_feats" key in base_config,
        # which wins in the update below) holds the real feature dimension.
        # Otherwise load() would rebuild the encoder with in_feats=None.
        config = {
            "tdnn_type": self.tdnn_type,
            "num_enc_blocks": self.num_enc_blocks,
            "in_feats": self.in_feats,
            "enc_hid_units": self.enc_hid_units,
            "enc_expand_units": self.enc_expand_units,
            "kernel_size": self.kernel_size,
            "dilation": self.dilation,
            "dilation_factor": self.dilation_factor,
            "in_norm": self.in_norm,
        }

        config.update(base_config)
        return config

    @classmethod
    def load(cls, file_path=None, cfg=None, state_dict=None):
        """Build a model from a checkpoint file and/or an explicit config.

        ``cfg`` and ``state_dict`` are resolved by the base-class helper
        ``_load_cfg_state_dict``. The weights are loaded only when a state
        dict is available.
        """
        cfg, state_dict = cls._load_cfg_state_dict(file_path, cfg, state_dict)
        model = cls(**cfg)
        if state_dict is not None:
            model.load_state_dict(state_dict)

        return model

    @staticmethod
    def filter_args(**kwargs):
        """Select constructor kwargs for the x-vector head and TDNN encoder.

        Combines ``XVector.filter_args`` and ``TF.filter_args``. On key
        clashes, the TDNN values win.
        """
        base_args = XVector.filter_args(**kwargs)
        child_args = TF.filter_args(**kwargs)

        base_args.update(child_args)
        return base_args

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Register the x-vector and TDNN encoder options on ``parser``.

        When ``prefix`` is given, the options are placed in a nested
        sub-parser attached to ``parser`` as ``--<prefix>``.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        XVector.add_class_args(parser)
        TF.add_class_args(parser)

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    # Backward-compatible alias.
    add_argparse_args = add_class_args