"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
from jsonargparse import ArgumentParser, ActionParser
import torch.nn as nn
from torch.nn import Linear
from ..layers import CosLossOutput, ArcLossOutput, SubCenterArcLossOutput
from ..layers import NormLayer1dFactory as NLF
from ..layer_blocks import FCBlock
from .net_arch import NetArch
[docs]class ClassifHead(NetArch):
"""Classification Head for x-vector style networks
Attributes:
in_feats: input features
num_classes: number of output classes
embed_dim: dimension of embedding layer
num_embed_layers: number of hidden layers
hid_act: str or dict hidden activation type in ['relu', 'relu6', 'swish', ... ]
loss_type: type of loss function that will be used with the x-vector in ['softmax', 'cos-softmax', 'arc-softmax'],
corresponding to standard cross-entorpy, additive margin softmax or additive angular margin softmax.
s: scale parameter for cos-softmax and arc-softmax
margin: margin parameter for cos-softmax and arc-softmax
margin_warmup_epochs: number of epochs to anneal the margin from 0 to margin
num_subcenters: number of subcenters in subcenter losses
norm_layer: norm_layer object or str indicating type norm layer, if None it uses BatchNorm1d
use_norm: it True it uses layer/batch-normalization
norm_before: if True, layer-norm is before the activation function
"""
[docs] def __init__(
self,
in_feats,
num_classes,
embed_dim=256,
num_embed_layers=1,
hid_act={"name": "relu", "inplace": True},
loss_type="arc-softmax",
s=64,
margin=0.3,
margin_warmup_epochs=0,
num_subcenters=2,
norm_layer=None,
use_norm=True,
norm_before=True,
dropout_rate=0,
):
super().__init__()
assert num_embed_layers >= 1, "num_embed_layers (%d < 1)" % num_embed_layers
self.num_embed_layers = num_embed_layers
self.in_feats = in_feats
self.embed_dim = embed_dim
self.num_classes = num_classes
self.norm_layer = norm_layer
if use_norm:
norm_groups = None
if norm_layer == "group-norm":
norm_groups = min(embed_dim // 8, 32)
self._norm_layer = NLF.create(norm_layer, norm_groups)
else:
self._norm_layer = None
self.use_norm = use_norm
self.norm_before = norm_before
self.dropout_rate = dropout_rate
self.loss_type = loss_type
self.s = s
self.margin = margin
self.margin_warmup_epochs = margin_warmup_epochs
self.num_subcenters = num_subcenters
prev_feats = in_feats
fc_blocks = []
for i in range(num_embed_layers - 1):
fc_blocks.append(
FCBlock(
prev_feats,
embed_dim,
activation=hid_act,
dropout_rate=dropout_rate,
norm_layer=self._norm_layer,
use_norm=use_norm,
norm_before=norm_before,
)
)
prev_feats = embed_dim
if loss_type != "softmax":
act = None
else:
act = hid_act
fc_blocks.append(
FCBlock(
prev_feats,
embed_dim,
activation=act,
norm_layer=self._norm_layer,
use_norm=use_norm,
norm_before=norm_before,
)
)
self.fc_blocks = nn.ModuleList(fc_blocks)
# output layer
if loss_type == "softmax":
self.output = Linear(embed_dim, num_classes)
elif loss_type == "cos-softmax":
self.output = CosLossOutput(
embed_dim,
num_classes,
s=s,
margin=margin,
margin_warmup_epochs=margin_warmup_epochs,
)
elif loss_type == "arc-softmax":
self.output = ArcLossOutput(
embed_dim,
num_classes,
s=s,
margin=margin,
margin_warmup_epochs=margin_warmup_epochs,
)
elif loss_type == "subcenter-arc-softmax":
self.output = SubCenterArcLossOutput(
embed_dim,
num_classes,
num_subcenters,
s=s,
margin=margin,
margin_warmup_epochs=margin_warmup_epochs,
)
[docs] def rebuild_output_layer(
self, num_classes, loss_type, s, margin, margin_warmup_epochs, num_subcenters=2
):
embed_dim = self.embed_dim
self.num_classes = num_classes
self.loss_type = loss_type
self.s = s
self.margin = margin
self.margin_warmup_epochs = margin_warmup_epochs
self.num_subcenters = num_subcenters
if loss_type == "softmax":
self.output = Linear(embed_dim, num_classes)
elif loss_type == "cos-softmax":
self.output = CosLossOutput(
embed_dim,
num_classes,
s=s,
margin=margin,
margin_warmup_epochs=margin_warmup_epochs,
)
elif loss_type == "arc-softmax":
self.output = ArcLossOutput(
embed_dim,
num_classes,
s=s,
margin=margin,
margin_warmup_epochs=margin_warmup_epochs,
)
elif loss_type == "subcenter-arc-softmax":
self.output = SubCenterArcLossOutput(
embed_dim,
num_classes,
num_subcenters,
s=s,
margin=margin,
margin_warmup_epochs=margin_warmup_epochs,
)
[docs] def set_margin(self, margin):
if self.loss_type == "softmax":
return
self.margin = margin
self.output.margin = margin
[docs] def set_margin_warmup_epochs(self, margin_warmup_epochs):
if self.loss_type == "softmax":
return
self.margin_warmup_epochs = margin_warmup_epochs
self.output.margin_warmup_epochs = margin_warmup_epochs
[docs] def set_s(self, s):
if self.loss_type == "softmax":
return
self.s = s
self.output.s = s
[docs] def update_margin(self, epoch):
if hasattr(self.output, "update_margin"):
self.output.update_margin(epoch)
[docs] def freeze_layers(self, layer_list):
for l in layer_list:
for param in self.fc_blocks[l].parameters():
param.requires_grad = False
[docs] def put_layers_in_eval_mode(self, layer_list):
for l in layer_list:
self.fc_blocks[l].eval()
[docs] def forward(self, x, y=None):
for l in range(self.num_embed_layers):
x = self.fc_blocks[l](x)
if self.loss_type == "softmax":
y = self.output(x)
else:
y = self.output(x, y)
return y
[docs] def forward_hid_feats(self, x, y=None, layers=None, return_output=False):
assert layers is not None or return_output
if layers is None:
layers = []
h = []
for l in range(self.num_embed_layers):
x = self.fc_blocks[l](x)
if l in layers:
h.append(x)
if self.loss_type == "softmax":
y = self.output(x)
else:
y = self.output(x, y)
if return_output:
return h, y
return h
[docs] def get_config(self):
hid_act = AF.get_config(self.fc_blocks[0].activation)
config = {
"in_feats": self.in_feats,
"num_classes": self.num_classes,
"embed_dim": self.embed_dim,
"num_embed_layers": self.num_embed_layers,
"hid_act": hid_act,
"lost_type": self.lost_type,
"s": self.s,
"margin": self.margin,
"margin_warmup_epochs": self.margin_warmup_epochs,
"num_subcenters": self.num_subcenters,
"norm_layer": self.norm_layer,
"use_norm": self.use_norm,
"norm_before": self.norm_before,
"dropout_rate": self.dropout_rate,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
[docs] @staticmethod
def filter_args(**kwargs):
if "wo_norm" in kwargs:
kwargs["use_norm"] = not kwargs["wo_norm"]
del kwargs["wo_norm"]
if "norm_after" in kwargs:
kwargs["norm_before"] = not kwargs["norm_after"]
del kwargs["norm_after"]
valid_args = (
"num_classes",
"embed_dim",
"num_embed_layers",
"hid_act",
"loss_type",
"s",
"margin",
"margin_warmup_epochs",
"num_subcenters",
"use_norm",
"norm_before",
"dropout_rate",
"norm_layer",
)
args = dict((k, kwargs[k]) for k in valid_args if k in kwargs)
return args
[docs] @staticmethod
def add_class_args(parser, prefix=None):
if prefix is not None:
outer_parser = parser
parser = ArgumentParser(prog="")
parser.add_argument(
"--embed-dim", default=256, type=int, help=("x-vector dimension")
)
parser.add_argument(
"--num-embed-layers",
default=1,
type=int,
help=("number of layers in the classif head"),
)
try:
parser.add_argument("--hid-act", default="relu6", help="hidden activation")
except:
pass
parser.add_argument(
"--loss-type",
default="arc-softmax",
choices=["softmax", "arc-softmax", "cos-softmax", "subcenter-arc-softmax"],
help="loss type: softmax, arc-softmax, cos-softmax, subcenter-arc-softmax",
)
parser.add_argument("--s", default=64, type=float, help="scale for arcface")
parser.add_argument(
"--margin", default=0.3, type=float, help="margin for arcface, cosface,..."
)
parser.add_argument(
"--margin-warmup-epochs",
default=10,
type=float,
help="number of epoch until we set the final margin",
)
parser.add_argument(
"--num-subcenters",
default=2,
type=int,
help="number of subcenters in subcenter losses",
)
try:
parser.add_argument(
"--norm-layer",
default=None,
choices=[
"batch-norm",
"group-norm",
"instance-norm",
"instance-norm-affine",
"layer-norm",
],
help="type of normalization layer for all components of x-vector network",
)
except:
pass
parser.add_argument(
"--wo-norm",
default=False,
action="store_true",
help="without batch normalization",
)
parser.add_argument(
"--norm-after",
default=False,
action="store_true",
help="batch normalizaton after activation",
)
try:
parser.add_argument("--dropout-rate", default=0, type=float, help="dropout")
except:
pass
if prefix is not None:
outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
# help='classification head options')
add_argparse_args = add_class_args