"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
from jsonargparse import ArgumentParser, ActionParser
import torch
import torch.nn as nn
from .xvector import XVector
from ...narchs import ResNetFactory as RNF
class ResNetXVector(XVector):
    """x-vector model whose encoder network is created by ``ResNetFactory``.

    The constructor builds the encoder with ``RNF.create`` and hands it,
    together with the remaining options, to ``XVector.__init__``.  The
    encoder hyper-parameters are not duplicated on this object; they are
    exposed as read-only properties that read the corresponding attribute
    of ``self.encoder_net``.
    """

    # NOTE(review): ``hid_act`` has a mutable dict default that is shared by
    # every call relying on it.  This constructor only forwards the object;
    # confirm that neither RNF.create nor XVector.__init__ mutates it.
    def __init__(
        self,
        resnet_type,
        in_feats,
        num_classes,
        in_channels,
        conv_channels=64,
        base_channels=64,
        in_kernel_size=7,
        in_stride=1,
        zero_init_residual=False,
        groups=1,
        replace_stride_with_dilation=None,
        do_maxpool=False,
        pool_net="mean+stddev",
        embed_dim=256,
        num_embed_layers=1,
        hid_act={"name": "relu", "inplace": True},
        loss_type="arc-softmax",
        s=64,
        margin=0.3,
        margin_warmup_epochs=0,
        num_subcenters=2,
        dropout_rate=0,
        norm_layer=None,
        head_norm_layer=None,
        use_norm=True,
        norm_before=True,
        in_norm=False,
        embed_layer=0,
        proj_feats=None,
        se_r=16,
        res2net_scale=4,
        res2net_width_factor=1,
    ):
        """Build the ResNet encoder and initialize the x-vector model.

        Args:
          resnet_type: first positional argument of ``RNF.create``; also
            stored as ``self.resnet_type``.
          in_feats: forwarded to both ``RNF.create`` and
            ``XVector.__init__``.
          num_classes: second positional argument of ``XVector.__init__``.
          in_channels: second positional argument of ``RNF.create``.
          conv_channels, base_channels, in_kernel_size, in_stride,
          zero_init_residual, groups, replace_stride_with_dilation,
          do_maxpool, in_norm, se_r, res2net_scale, res2net_width_factor:
            forwarded by keyword to ``RNF.create`` only.
          pool_net, embed_dim, num_embed_layers, loss_type, s, margin,
          margin_warmup_epochs, num_subcenters, head_norm_layer, use_norm,
          embed_layer, proj_feats:
            forwarded by keyword to ``XVector.__init__`` only.
          hid_act, dropout_rate, norm_layer, norm_before: forwarded by
            keyword to both ``RNF.create`` and ``XVector.__init__``.
        """
        logging.info("making %s encoder network", resnet_type)
        encoder_net = RNF.create(
            resnet_type,
            in_channels,
            conv_channels=conv_channels,
            base_channels=base_channels,
            hid_act=hid_act,
            in_kernel_size=in_kernel_size,
            in_stride=in_stride,
            zero_init_residual=zero_init_residual,
            groups=groups,
            replace_stride_with_dilation=replace_stride_with_dilation,
            dropout_rate=dropout_rate,
            norm_layer=norm_layer,
            norm_before=norm_before,
            do_maxpool=do_maxpool,
            in_norm=in_norm,
            se_r=se_r,
            in_feats=in_feats,
            res2net_scale=res2net_scale,
            res2net_width_factor=res2net_width_factor,
        )

        super().__init__(
            encoder_net,
            num_classes,
            pool_net=pool_net,
            embed_dim=embed_dim,
            num_embed_layers=num_embed_layers,
            hid_act=hid_act,
            loss_type=loss_type,
            s=s,
            margin=margin,
            margin_warmup_epochs=margin_warmup_epochs,
            num_subcenters=num_subcenters,
            norm_layer=norm_layer,
            head_norm_layer=head_norm_layer,
            use_norm=use_norm,
            norm_before=norm_before,
            dropout_rate=dropout_rate,
            embed_layer=embed_layer,
            in_feats=in_feats,
            proj_feats=proj_feats,
        )

        # Set after super().__init__(); kept here because get_config()
        # reports it alongside the values read back from the encoder.
        self.resnet_type = resnet_type

    # The properties below read from ``self.encoder_net``.
    # NOTE(review): that attribute is not assigned in this class; presumably
    # XVector.__init__ stores the encoder under that name -- confirm.

    @property
    def in_channels(self):
        """Value of ``self.encoder_net.in_channels``."""
        return self.encoder_net.in_channels

    @property
    def conv_channels(self):
        """Value of ``self.encoder_net.conv_channels``."""
        return self.encoder_net.conv_channels

    @property
    def base_channels(self):
        """Value of ``self.encoder_net.base_channels``."""
        return self.encoder_net.base_channels

    @property
    def in_kernel_size(self):
        """Value of ``self.encoder_net.in_kernel_size``."""
        return self.encoder_net.in_kernel_size

    @property
    def in_stride(self):
        """Value of ``self.encoder_net.in_stride``."""
        return self.encoder_net.in_stride

    @property
    def zero_init_residual(self):
        """Value of ``self.encoder_net.zero_init_residual``."""
        return self.encoder_net.zero_init_residual

    @property
    def groups(self):
        """Value of ``self.encoder_net.groups``."""
        return self.encoder_net.groups

    @property
    def replace_stride_with_dilation(self):
        """Value of ``self.encoder_net.replace_stride_with_dilation``."""
        return self.encoder_net.replace_stride_with_dilation

    @property
    def do_maxpool(self):
        """Value of ``self.encoder_net.do_maxpool``."""
        return self.encoder_net.do_maxpool

    @property
    def in_norm(self):
        """Value of ``self.encoder_net.in_norm``."""
        return self.encoder_net.in_norm

    @property
    def se_r(self):
        """Value of ``self.encoder_net.se_r``."""
        return self.encoder_net.se_r

    @property
    def res2net_scale(self):
        """Value of ``self.encoder_net.res2net_scale``."""
        return self.encoder_net.res2net_scale

    @property
    def res2net_width_factor(self):
        """Value of ``self.encoder_net.res2net_width_factor``."""
        return self.encoder_net.res2net_width_factor

    def get_config(self):
        """Return the configuration dict of this model.

        The parent's config is taken with its ``"encoder_cfg"`` entry
        removed, and the encoder options are reported as flat top-level keys
        instead.  On a key collision the parent's value wins, because the
        parent config is applied last.

        Returns:
          dict with the ResNet options merged with the parent config.

        Raises:
          KeyError: if the parent config has no ``"encoder_cfg"`` key.
        """
        base_config = super().get_config()
        del base_config["encoder_cfg"]

        config = {
            "resnet_type": self.resnet_type,
            "in_channels": self.in_channels,
            "conv_channels": self.conv_channels,
            "base_channels": self.base_channels,
            "in_kernel_size": self.in_kernel_size,
            "in_stride": self.in_stride,
            "zero_init_residual": self.zero_init_residual,
            "groups": self.groups,
            "replace_stride_with_dilation": self.replace_stride_with_dilation,
            "do_maxpool": self.do_maxpool,
            "in_norm": self.in_norm,
            "se_r": self.se_r,
            "res2net_scale": self.res2net_scale,
            "res2net_width_factor": self.res2net_width_factor,
        }

        config.update(base_config)
        return config

    @classmethod
    def load(cls, file_path=None, cfg=None, state_dict=None):
        """Create a model from a config and optionally load its weights.

        The three arguments are resolved by ``cls._load_cfg_state_dict``
        (defined outside this class) into a ``(cfg, state_dict)`` pair.

        Args:
          file_path: passed through to ``_load_cfg_state_dict``.
          cfg: passed through to ``_load_cfg_state_dict``; the resolved
            value is expanded as keyword arguments of the constructor.
          state_dict: passed through to ``_load_cfg_state_dict``; when the
            resolved value is not None it is loaded into the new model.

        Returns:
          The constructed model.
        """
        cfg, state_dict = cls._load_cfg_state_dict(file_path, cfg, state_dict)

        model = cls(**cfg)
        if state_dict is not None:
            model.load_state_dict(state_dict)

        return model

    # Declared static so it can be called on instances as well as on the
    # class; without the decorator an instance call would bind ``self`` as a
    # positional argument and fail.
    @staticmethod
    def filter_args(**kwargs):
        """Combine the filtered arguments of ``XVector`` and ``ResNetFactory``.

        Args:
          **kwargs: candidate arguments, given unchanged to both
            ``XVector.filter_args`` and ``RNF.filter_args``.

        Returns:
          The result of ``XVector.filter_args`` updated in place with the
          result of ``RNF.filter_args`` (the latter wins on shared keys).
        """
        base_args = XVector.filter_args(**kwargs)
        child_args = RNF.filter_args(**kwargs)

        base_args.update(child_args)
        return base_args

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Register the ``XVector`` and ``ResNetFactory`` options on a parser.

        Args:
          parser: parser that receives the options.
          prefix: when not None, the options are added to a fresh
            ``ArgumentParser`` which is then attached to ``parser`` as the
            single argument ``--<prefix>`` through ``ActionParser``; when
            None, the options are added to ``parser`` directly.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        XVector.add_class_args(parser)
        RNF.add_class_args(parser)

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    # Alias: the same static method under a second name.
    add_argparse_args = add_class_args