"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba, Nanxin Chen)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import sys
import logging
import math
import torch
import torch.nn as nn
import torch.cuda.amp as amp
def _l2_norm(x, axis=-1):
    """Scale ``x`` to unit L2 norm along ``axis``.

    The norm is evaluated on a float32 copy of ``x`` with autocast turned
    off, and a small constant (1e-10) is added to it before dividing so an
    all-zero slice maps to zeros instead of raising a division by zero.

    Args:
        x: tensor to normalize.
        axis: dimension along which the norm is computed (default: last).

    Returns:
        ``x`` divided by its (kept-dimension) L2 norm along ``axis``.
    """
    with amp.autocast(enabled=False):
        denom = torch.norm(x.float(), p=2, dim=axis, keepdim=True) + 1e-10
        return torch.div(x, denom)
class ArcLossOutput(nn.Module):
    """Arc-softmax output layer.

    Produces scaled cosine logits between the L2-normalized input vectors
    and the L2-normalized columns of ``kernel``. In training mode, when
    labels are given, the logit of each sample's target class is replaced
    by ``cos(theta + m)``, where ``m`` is the current margin.

    Attributes:
        in_feats: input feature dimension (rows of ``kernel``).
        num_classes: number of output classes (columns of ``kernel``).
        s: factor that multiplies the cosine logits.
        margin: final angular margin.
        margin_warmup_epochs: number of epochs over which the margin grows
            linearly from 0 to ``margin``; 0 disables the warm-up.
        cur_margin: margin currently applied in ``forward``.
        kernel: learnable weight matrix of shape (in_feats, num_classes).
    """

    def __init__(self, in_feats, num_classes, s=64, margin=0.3, margin_warmup_epochs=0):
        super().__init__()
        self.in_feats = in_feats
        self.num_classes = num_classes
        self.s = s
        self.margin = margin
        self.margin_warmup_epochs = margin_warmup_epochs
        # Without warm-up the full margin is used from the very first step.
        self.cur_margin = margin if margin_warmup_epochs == 0 else 0
        self._compute_aux()

        self.kernel = nn.Parameter(torch.Tensor(in_feats, num_classes))
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        s = "%s(in_feats=%d, num_classes=%d, s=%.2f, margin=%.2f, margin_warmup_epochs=%d)" % (
            self.__class__.__name__,
            self.in_feats,
            self.num_classes,
            self.s,
            self.margin,
            self.margin_warmup_epochs,
        )
        return s

    def _compute_aux(self):
        """Cache cos(m) and sin(m) for the current margin."""
        logging.info("updating arc-softmax margin=%.2f", self.cur_margin)
        self.cos_m = math.cos(self.cur_margin)
        self.sin_m = math.sin(self.cur_margin)

    def update_margin(self, epoch):
        """Update the current margin according to the warm-up schedule.

        Args:
            epoch: current epoch; while it is below ``margin_warmup_epochs``
                the margin is ``margin * epoch / margin_warmup_epochs``,
                afterwards it is ``margin``.
        """
        if self.margin_warmup_epochs == 0:
            return

        if epoch < self.margin_warmup_epochs:
            self.cur_margin = self.margin * epoch / self.margin_warmup_epochs
        elif self.cur_margin != self.margin:
            self.cur_margin = self.margin
        else:
            # Margin already at its final value: nothing to recompute.
            return

        self._compute_aux()

    def forward(self, x, y=None):
        """Compute the scaled (margin-penalized) cosine logits.

        Args:
            x: 2D tensor (batch, in_feats); it is multiplied with ``kernel``
                through ``torch.mm``.
            y: optional integer class indices, one per row of ``x``. The
                margin is applied only when ``y`` is given and the module is
                in training mode.

        Returns:
            Float tensor (batch, num_classes) with the logits multiplied
            by ``s``.
        """
        with amp.autocast(enabled=False):
            batch_size = len(x)
            x = _l2_norm(x.float())
            kernel_norm = _l2_norm(self.kernel, axis=0)
            cos_theta = torch.mm(x, kernel_norm).float()
            cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
            # Multiplying by 1.0 yields a new tensor, so the in-place writes
            # below never modify cos_theta itself.
            output = cos_theta * 1.0

            if y is not None and self.training:
                # Adding 1e-10 to 1 is rounded away in float32, so the floor
                # is enforced with clamp instead; it keeps sqrt away from 0,
                # where its gradient is infinite.
                sin_theta_2 = (1.0 - torch.pow(cos_theta, 2)).clamp(min=1e-10)
                sin_theta = torch.sqrt(sin_theta_2)
                # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
                cos_theta_m = cos_theta * self.cos_m - sin_theta * self.sin_m
                # Build the row indices on the logits' device to avoid a
                # host-to-device copy when running on an accelerator.
                idx_ = torch.arange(
                    0, batch_size, dtype=torch.long, device=output.device
                )
                output[idx_, y] = cos_theta_m[idx_, y]

            output *= self.s  # scale up in order to make softmax work
        return output
# @amp.float_function
# def forward(self, x, y=None):
# s = self.s
# #print(x)
# if len(x)==24:
# logging.info('x={}'.format(str(x[9])))
# x = _l2_norm(x)
# if len(x)==24:
# logging.info('xn={}'.format(str(x[9])))
# batch_size = len(x)
# kernel_norm = _l2_norm(self.kernel, axis=0)
# # cos(theta+m)
# cos_theta = torch.mm(x, kernel_norm).float()
# cos_theta = cos_theta.clamp(-1,1) # for numerical stability
# #print(cos_theta)
# output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
# if len(x)==24:
# logging.info('o={}'.format(str(output[9])))
# if y is not None and self.training and False:
# cos_theta_2 = torch.pow(cos_theta, 2)
# sin_theta_2 = (1 + 1e-10) - cos_theta_2
# sin_theta = torch.sqrt(sin_theta_2)
# cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
# idx_ = torch.arange(0, batch_size, dtype=torch.long)
# output[idx_, y] = cos_theta_m[idx_, y]
# #print(output)
# #sys.flush.stdout()
# #print(ss)
# output *= s # scale up in order to make softmax work
# if len(x)==24:
# logging.info('so={}'.format(str(output[9])))
# return output
class CosLossOutput(nn.Module):
    """Cos-softmax output layer.

    Produces scaled cosine logits between the L2-normalized input vectors
    and the L2-normalized columns of ``kernel``. In training mode, when
    labels are given, the current margin is subtracted from the cosine of
    each sample's target class before scaling.

    Attributes:
        in_feats: input feature dimension (rows of ``kernel``).
        num_classes: number of output classes (columns of ``kernel``).
        s: factor that multiplies the cosine logits.
        margin: final additive cosine margin.
        margin_warmup_epochs: number of epochs over which the margin grows
            linearly from 0 to ``margin``; 0 disables the warm-up.
        cur_margin: margin currently applied in ``forward``.
        kernel: learnable weight matrix of shape (in_feats, num_classes).
    """

    def __init__(self, in_feats, num_classes, s=64, margin=0.3, margin_warmup_epochs=0):
        super().__init__()
        self.in_feats = in_feats
        self.num_classes = num_classes
        self.s = s
        self.margin = margin
        self.margin_warmup_epochs = margin_warmup_epochs
        # Without warm-up the full margin is used from the very first step.
        self.cur_margin = margin if margin_warmup_epochs == 0 else 0

        self.kernel = nn.Parameter(torch.Tensor(in_feats, num_classes))
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        s = "%s(in_feats=%d, num_classes=%d, s=%.2f, margin=%.2f, margin_warmup_epochs=%d)" % (
            self.__class__.__name__,
            self.in_feats,
            self.num_classes,
            self.s,
            self.margin,
            self.margin_warmup_epochs,
        )
        return s

    def update_margin(self, epoch):
        """Update the current margin according to the warm-up schedule.

        Args:
            epoch: current epoch; while it is below ``margin_warmup_epochs``
                the margin is ``margin * epoch / margin_warmup_epochs``,
                afterwards it is ``margin``.
        """
        if self.margin_warmup_epochs == 0:
            return

        if epoch < self.margin_warmup_epochs:
            self.cur_margin = self.margin * epoch / self.margin_warmup_epochs
        elif self.cur_margin != self.margin:
            self.cur_margin = self.margin
        else:
            # Margin already at its final value: nothing changed, no log.
            return

        logging.info("updating cos-softmax margin=%.2f", self.cur_margin)

    def forward(self, x, y=None):
        """Compute the scaled (margin-penalized) cosine logits.

        Args:
            x: 2D tensor (batch, in_feats); it is multiplied with ``kernel``
                through ``torch.mm``.
            y: optional integer class indices, one per row of ``x``. The
                margin is applied only when ``y`` is given and the module is
                in training mode.

        Returns:
            Float tensor (batch, num_classes) with the logits multiplied
            by ``s``.
        """
        with amp.autocast(enabled=False):
            x = _l2_norm(x.float())
            batch_size = len(x)
            kernel_norm = _l2_norm(self.kernel, axis=0)
            cos_theta = torch.mm(x, kernel_norm).float()
            cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
            # Multiplying by 1.0 yields a new tensor, so the in-place writes
            # below never modify cos_theta itself.
            output = cos_theta * 1.0

            if y is not None and self.training:
                # Build the row indices on the logits' device to avoid a
                # host-to-device copy when running on an accelerator.
                idx_ = torch.arange(
                    0, batch_size, dtype=torch.long, device=output.device
                )
                # cos(theta) - m for the target class of every sample.
                output[idx_, y] = cos_theta[idx_, y] - self.cur_margin

            output *= self.s  # scale up in order to make softmax work
        return output
class SubCenterArcLossOutput(ArcLossOutput):
    """Arc-softmax output layer with several sub-centers per class.

    The kernel holds ``num_classes * num_subcenters`` columns. For every
    class, the cosine logit is the maximum over that class's sub-centers;
    the angular margin is then applied to the target class as in the
    parent class.

    Attributes:
        num_classes: number of output classes.
        num_subcenters: number of kernel columns assigned to each class.
    """

    def __init__(
        self,
        in_feats,
        num_classes,
        num_subcenters=2,
        s=64,
        margin=0.3,
        margin_warmup_epochs=0,
    ):
        # The parent allocates one kernel column per (class, sub-center).
        super().__init__(
            in_feats, num_classes * num_subcenters, s, margin, margin_warmup_epochs
        )
        # The parent stored the total column count; restore the class count.
        self.num_classes = num_classes
        self.num_subcenters = num_subcenters

    def __str__(self):
        s = "%s(in_feats=%d, num_classes=%d, num_subcenters=%d, s=%.2f, margin=%.2f, margin_warmup_epochs=%d)" % (
            self.__class__.__name__,
            self.in_feats,
            self.num_classes,
            self.num_subcenters,
            self.s,
            self.margin,
            self.margin_warmup_epochs,
        )
        return s

    def forward(self, x, y=None):
        """Compute the scaled (margin-penalized) cosine logits.

        Args:
            x: 2D tensor (batch, in_feats); it is multiplied with ``kernel``
                through ``torch.mm``.
            y: optional integer class indices, one per row of ``x``. The
                margin is applied only when ``y`` is given and the module is
                in training mode.

        Returns:
            Float tensor (batch, num_classes) with the logits multiplied
            by ``s``.
        """
        with amp.autocast(enabled=False):
            batch_size = len(x)
            x = _l2_norm(x.float())
            kernel_norm = _l2_norm(self.kernel, axis=0)
            cos_theta = torch.mm(x, kernel_norm).float()
            # Keep, for every class, the best-matching sub-center.
            cos_theta = torch.max(
                cos_theta.view(-1, self.num_classes, self.num_subcenters), dim=-1
            )[0]
            cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
            # Multiplying by 1.0 yields a new tensor, so the in-place writes
            # below never modify cos_theta itself.
            output = cos_theta * 1.0

            if y is not None and self.training:
                # Adding 1e-10 to 1 is rounded away in float32, so the floor
                # is enforced with clamp instead; it keeps sqrt away from 0,
                # where its gradient is infinite.
                sin_theta_2 = (1.0 - torch.pow(cos_theta, 2)).clamp(min=1e-10)
                sin_theta = torch.sqrt(sin_theta_2)
                # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
                cos_theta_m = cos_theta * self.cos_m - sin_theta * self.sin_m
                # Build the row indices on the logits' device to avoid a
                # host-to-device copy when running on an accelerator.
                idx_ = torch.arange(
                    0, batch_size, dtype=torch.long, device=output.device
                )
                output[idx_, y] = cos_theta_m[idx_, y]

            output *= self.s  # scale up in order to make softmax work
        return output