"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv1d, Linear, BatchNorm1d
from ..layers import ActivationFactory as AF
from ..layers import NormLayer2dFactory as NLF
from ..layer_blocks import (
ResNetInputBlock,
ResNetBasicBlock,
ResNetBNBlock,
SEResNetBasicBlock,
SEResNetBNBlock,
Res2NetBasicBlock,
Res2NetBNBlock,
)
from ..layer_blocks import ResNetEndpointBlock
from .net_arch import NetArch
[docs]class ResNet(NetArch):
"""ResNet2D base class
Attributes:
block: resnet basic block type in ['basic', 'bn', 'sebasic', 'sebn'], meaning
basic resnet block, bottleneck resnet block, basic block with squeeze-excitation,
and bottleneck block with squeeze-excitation
num_layers: list with the number of layers in each of the 4 layer blocks that we find in
resnets, after each layer block feature maps are downsmapled times 2 in each dimension
and channels are upsampled times 2.
in_channels: number of input channels
conv_channels: number of output channels in first conv layer (stem)
base_channels: number of channels in the first layer block
out_units: number of logits in the output layer, if 0 there is no output layer and resnet is used just
as feature extractor, for example for x-vector encoder.
in_kernel_size: kernels size of first conv layer
hid_act: str or dictionary describing hidden activations.
out_act: output activation
zero_init_residual: initializes batchnorm weights to zero so each residual block behaves as identitiy at
the beggining. We observed worse results when using this option in x-vectors
groups: number of groups in convolutions
replace_stride_with_dilation: use dialted conv nets instead of downsammpling, we never tested this.
dropout_rate: dropout rate
norm_layer: norm_layer object or str indicating type layer-norm object, if None it uses BatchNorm2d
do_maxpool: if False, removes the maxpooling layer at the stem of the network.
in_norm: if True, adds another batch norm layer in the input
se_r: squeeze-excitation dimension compression
time_se: if True squeeze-excitation embedding is obtaining by averagin only in the time dimension,
instead of time-freq dimension or HxW dimensions
in_feats: input feature size (number of components in dimension of 2 of input tensor), this is only
required when time_se=True to calculcate the size of the squeeze excitation matrices.
"""
[docs] def __init__(
self,
block,
num_layers,
in_channels,
conv_channels=64,
base_channels=64,
out_units=0,
hid_act={"name": "relu6", "inplace": True},
out_act=None,
in_kernel_size=7,
in_stride=2,
zero_init_residual=False,
multilevel=False,
endpoint_channels=64,
groups=1,
replace_stride_with_dilation=None,
dropout_rate=0,
norm_layer=None,
norm_before=True,
do_maxpool=True,
in_norm=True,
se_r=16,
time_se=False,
in_feats=None,
res2net_scale=4,
res2net_width_factor=1,
):
super().__init__()
logging.info("{}".format(locals()))
self.block = block
self.has_se = False
self.is_res2net = False
if isinstance(block, str):
if block == "basic":
self._block = ResNetBasicBlock
elif block == "bn":
self._block = ResNetBNBlock
elif block == "sebasic":
self._block = SEResNetBasicBlock
self.has_se = True
elif block == "sebn":
self._block = SEResNetBNBlock
self.has_se = True
elif block == "res2basic":
self._block = Res2NetBasicBlock
self.is_res2net = True
elif block == "res2bn":
self._block = Res2NetBNBlock
self.is_res2net = True
elif block == "seres2bn" or block == "tseres2bn":
self._block = Res2NetBNBlock
self.has_se = True
self.is_res2net = True
else:
self._block = block
self.num_layers = num_layers
self.in_channels = in_channels
self.conv_channels = conv_channels
self.base_channels = base_channels
self.out_units = out_units
self.in_kernel_size = in_kernel_size
self.in_stride = in_stride
self.hid_act = hid_act
self.groups = groups
self.norm_before = norm_before
self.do_maxpool = do_maxpool
self.in_norm = in_norm
self.dropout_rate = dropout_rate
# self.width_per_group = width_per_group
self.se_r = se_r
self.time_se = time_se
self.in_feats = in_feats
self.res2net_scale = res2net_scale
self.res2net_width_factor = res2net_width_factor
self.multilevel = multilevel
self.endpoint_channels = endpoint_channels
self.norm_layer = norm_layer
norm_groups = None
if norm_layer == "group-norm":
norm_groups = min(base_channels // 2, 32)
norm_groups = max(norm_groups, groups)
self._norm_layer = NLF.create(norm_layer, norm_groups)
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
)
self.replace_stride_with_dilation = replace_stride_with_dilation
self.groups = groups
# self.width_per_group = width_per_group
if in_norm:
self.in_bn = norm_layer(in_channels)
self.in_block = ResNetInputBlock(
in_channels,
conv_channels,
kernel_size=in_kernel_size,
stride=in_stride,
activation=hid_act,
norm_layer=self._norm_layer,
norm_before=norm_before,
do_maxpool=do_maxpool,
)
self._context = self.in_block.context
self._downsample_factor = self.in_block.downsample_factor
self.cur_in_channels = conv_channels
self.layer1 = self._make_layer(self._block, base_channels, num_layers[0])
self.layer2 = self._make_layer(
self._block,
2 * base_channels,
num_layers[1],
stride=2,
dilate=replace_stride_with_dilation[0],
)
self.layer3 = self._make_layer(
self._block,
4 * base_channels,
num_layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
)
self.layer4 = self._make_layer(
self._block,
8 * base_channels,
num_layers[3],
stride=2,
dilate=replace_stride_with_dilation[2],
)
if self.multilevel:
self.endpoint2 = ResNetEndpointBlock(
2 * base_channels * self._block.expansion,
self.endpoint_channels,
1,
activation=self.hid_act,
norm_layer=self._norm_layer,
norm_before=self.norm_before,
)
self.endpoint3 = ResNetEndpointBlock(
4 * base_channels * self._block.expansion,
self.endpoint_channels,
2,
activation=self.hid_act,
norm_layer=self._norm_layer,
norm_before=self.norm_before,
)
self.endpoint4 = ResNetEndpointBlock(
8 * base_channels * self._block.expansion,
self.endpoint_channels,
4,
activation=self.hid_act,
norm_layer=self._norm_layer,
norm_before=self.norm_before,
)
self.with_output = False
self.out_act = None
if out_units > 0:
self.with_output = True
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.output = nn.Linear(self.cur_in_channels, out_units)
self.out_act = AF.create(out_act)
for m in self.modules():
if isinstance(m, nn.Conv2d):
act_name = "relu"
if isinstance(hid_act, str):
act_name = hid_act
if isinstance(hid_act, dict):
act_name = hid_act["name"]
if act_name == "swish":
act_name = "relu"
try:
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity=act_name
)
except:
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
self.zero_init_residual = zero_init_residual
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBNBlock):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, ResNetBasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, channels, num_blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
kwargs = {}
if self.has_se:
if self.time_se:
num_feats = int(self.in_feats / (self._downsample_factor * stride))
kwargs = {"se_r": self.se_r, "time_se": True, "num_feats": num_feats}
else:
kwargs = {"se_r": self.se_r}
if self.is_res2net:
kwargs["scale"] = self.res2net_scale
kwargs["width_factor"] = self.res2net_width_factor
layers = []
layers.append(
block(
self.cur_in_channels,
channels,
activation=self.hid_act,
stride=stride,
dropout_rate=self.dropout_rate,
groups=self.groups,
dilation=previous_dilation,
norm_layer=self._norm_layer,
norm_before=self.norm_before,
**kwargs
)
)
self._context += layers[0].context * self._downsample_factor
self._downsample_factor *= layers[0].downsample_factor
self.cur_in_channels = channels * block.expansion
for _ in range(1, num_blocks):
layers.append(
block(
self.cur_in_channels,
channels,
activation=self.hid_act,
dropout_rate=self.dropout_rate,
groups=self.groups,
dilation=self.dilation,
norm_layer=self._norm_layer,
norm_before=self.norm_before,
**kwargs
)
)
self._context += layers[-1].context * self._downsample_factor
return nn.Sequential(*layers)
[docs] def _compute_out_size(self, in_size):
"""Computes output size given input size.
Output size is not the same as input size because of
downsampling steps.
Args:
in_size: input size of the H or W dimensions
Returns:
output_size
"""
out_size = int((in_size - 1) // self.in_stride + 1)
if self.do_maxpool:
out_size = int((out_size - 1) // 2 + 1)
for i in range(3):
if not self.replace_stride_with_dilation[i]:
out_size = int((out_size - 1) // 2 + 1)
return out_size
[docs] def in_context(self):
"""
Returns:
Tuple (past, future) context required to predict one frame.
"""
return (self._context, self._context)
[docs] def in_shape(self):
"""
Returns:
Tuple describing input shape for the network
"""
return (None, self.in_channels, None, None)
[docs] def out_shape(self, in_shape=None):
"""Computes the output shape given the input shape
Args:
in_shape: input shape
Returns:
Tuple describing output shape for the network
"""
if self.with_output:
return (None, self.out_units)
if in_shape is None:
return (None, self.layer4[-1].out_channels, None, None)
assert len(in_shape) == 4
if in_shape[2] is None:
H = None
else:
H = self._compute_out_size(in_shape[2])
if in_shape[3] is None:
W = None
else:
W = self._compute_out_size(in_shape[3])
if self.multilevel:
return (in_shape[0], self.endpoint_channels, int(in_shape[2] // 2), None)
return (in_shape[0], self.layer4[-1].out_channels, H, W)
[docs] def forward(self, x, use_amp=False):
if use_amp:
with torch.cuda.amp.autocast():
return self._forward(x)
return self._forward(x)
[docs] def _forward(self, x):
"""forward function
Args:
x: input tensor of size=(batch, Cin, Hin, Win) for image or
size=(batch, C, freq, time) for audio
Returns:
Tensor with output logits of size=(batch, out_units) if out_units>0,
otherwise, it returns tensor of represeantions of size=(batch, Cout, Hout, Wout)
"""
if self.in_norm:
x = self.in_bn(x)
feats = []
x = self.in_block(x)
x = self.layer1(x)
x = self.layer2(x)
if self.multilevel:
feats.append(x)
x = self.layer3(x)
if self.multilevel:
feats.append(x)
x = self.layer4(x)
if self.multilevel:
feats.append(x)
if self.multilevel:
out2 = self.endpoint2(feats[0])
out3 = self.endpoint3(feats[1])
out4 = self.endpoint4(feats[2])
x = torch.mean(torch.stack([out2, out3, out4]), 0)
if self.with_output:
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.output(x)
if self.out_act is not None:
x = self.out_act(x)
return x
[docs] def forward_hid_feats(self, x, layers=None, return_output=False):
"""forward function which also returns intermediate hidden representations
Args:
x: input tensor of size=(batch, Cin, Hin, Win) for image or
size=(batch, C, freq, time) for audio
layers: list of hidden layers to return hidden representations
return_output: if True if returns the output representations in a separate
tensor.
Returns:
List of hidden representation tensors
Tensor with output representations if return_output is True
"""
assert layers is not None or return_output
if layers is None:
layers = []
if return_output:
last_layer = 4
else:
last_layer = max(layers)
h = []
feats = []
if self.in_norm:
x = self.in_bn(x)
x = self.in_block(x)
if 0 in layers:
h.append(x)
if last_layer == 0:
return h
x = self.layer1(x)
if 1 in layers:
h.append(x)
if last_layer == 1:
return h
x = self.layer2(x)
if 2 in layers:
h.append(x)
if last_layer == 2:
return h
if return_output and self.multilevel:
feats.append(x)
x = self.layer3(x)
if 3 in layers:
h.append(x)
if last_layer == 3:
return h
if return_output and self.multilevel:
feats.append(x)
x = self.layer4(x)
if 4 in layers:
h.append(x)
if return_output and self.multilevel:
feats.append(x)
if return_output:
if self.multilevel:
out2 = self.endpoint2(feats[0])
out3 = self.endpoint3(feats[1])
out4 = self.endpoint4(feats[2])
x = torch.mean(torch.stack([out2, out3, out4]), 0)
return h, x
return h
[docs] def get_config(self):
"""Gets network config
Returns:
dictionary with config params
"""
out_act = AF.get_config(self.out_act)
hid_act = self.hid_act
config = {
"block": self.block,
"num_layers": self.num_layers,
"in_channels": self.in_channels,
"conv_channels": self.conv_channels,
"base_channels": self.base_channels,
"out_units": self.out_units,
"in_kernel_size": self.in_kernel_size,
"in_stride": self.in_stride,
"zero_init_residual": self.zero_init_residual,
"groups": self.groups,
"replace_stride_with_dilation": self.replace_stride_with_dilation,
"dropout_rate": self.dropout_rate,
"norm_layer": self.norm_layer,
"norm_before": self.norm_before,
"in_norm": self.in_norm,
"do_maxpool": self.do_maxpool,
"out_act": out_act,
"hid_act": hid_act,
"se_r": self.se_r,
"in_feats": self.in_feats,
"res2net_scale": self.res2net_scale,
"res2net_width_factor": self.res2net_width_factor,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
# Standard ResNets
[docs]class ResNet18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("basic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class ResNet34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("basic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class ResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class ResNet101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class ResNet152(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("bn", [3, 8, 36, 3], in_channels, **kwargs)
[docs]class ResNext50_32x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 128
super().__init__("bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class ResNext101_32x8d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 256
super().__init__("bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class WideResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class WideResNet101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class LResNet18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("basic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class LResNet34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("basic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class LResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class LResNext50_4x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 4
kwargs["base_channels"] = 16
super().__init__("bn", [3, 4, 6, 3], in_channels, **kwargs)
# Squezee-Excitation ResNets
[docs]class SEResNet18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("sebasic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class SEResNet34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("sebasic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SEResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SEResNet101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("sebn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class SEResNet152(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("sebn", [3, 8, 36, 3], in_channels, **kwargs)
[docs]class SEResNext50_32x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 128
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SEResNext101_32x8d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 256
super().__init__("sebn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class SEWideResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SEWideResNet101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("sebn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class SELResNet18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("sebasic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class SELResNet34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("sebasic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SELResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SELResNext50_4x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 4
kwargs["base_channels"] = 16
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
# Time dimension Squezee-Excitation ResNets
[docs]class TSEResNet18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("sebasic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class TSEResNet34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("sebasic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSEResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSEResNet101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class TSEResNet152(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("sebn", [3, 8, 36, 3], in_channels, **kwargs)
[docs]class TSEResNext50_32x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 128
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSEResNext101_32x8d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 256
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class TSEWideResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSEWideResNet101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class TSELResNet18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
kwargs["time_se"] = True
super().__init__("sebasic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class TSELResNet34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
kwargs["time_se"] = True
super().__init__("sebasic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSELResNet50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSELResNext50_4x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 4
kwargs["base_channels"] = 16
kwargs["time_se"] = True
super().__init__("sebn", [3, 4, 6, 3], in_channels, **kwargs)
#################### Res2Net variants ########################
# Standard Res2Nets
[docs]class Res2Net18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("res2basic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class Res2Net34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("res2basic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class Res2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("res2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class Res2Net101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("res2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class Res2Net152(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("res2bn", [3, 8, 36, 3], in_channels, **kwargs)
[docs]class Res2Next50_32x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 128
super().__init__("res2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class Res2Next101_32x8d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 256
super().__init__("res2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class WideRes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("res2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class WideRes2Net101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("res2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class LRes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("res2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class LRes2Next50_4x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 4
kwargs["base_channels"] = 16
super().__init__("res2bn", [3, 4, 6, 3], in_channels, **kwargs)
# Squezee-Excitation Res2Nets
[docs]class SERes2Net18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("seres2basic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class SERes2Net34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("seres2basic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SERes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SERes2Net101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("seres2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class SERes2Net152(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
super().__init__("seres2bn", [3, 8, 36, 3], in_channels, **kwargs)
[docs]class SERes2Next50_32x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 128
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SERes2Next101_32x8d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 256
super().__init__("seres2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class SEWideRes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SEWideRes2Net101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
super().__init__("seres2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class SELRes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class SELRes2Next50_4x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 4
kwargs["base_channels"] = 16
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
# Time dimension Squezee-Excitation Res2Nets
[docs]class TSERes2Net18(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("se2basic", [2, 2, 2, 2], in_channels, **kwargs)
[docs]class TSERes2Net34(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("se2basic", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSERes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSERes2Net101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class TSERes2Net152(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 8, 36, 3], in_channels, **kwargs)
[docs]class TSERes2Next50_32x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 128
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSERes2Next101_32x8d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 32
kwargs["base_channels"] = 256
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class TSEWideRes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSEWideRes2Net101(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["base_channels"] = 128
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 23, 3], in_channels, **kwargs)
[docs]class TSELRes2Net50(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
[docs]class TSELRes2Next50_4x4d(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["groups"] = 4
kwargs["base_channels"] = 16
kwargs["time_se"] = True
super().__init__("seres2bn", [3, 4, 6, 3], in_channels, **kwargs)
# multi-level feature ResNet
[docs]class LResNet34_345(ResNet):
[docs] def __init__(self, in_channels, **kwargs):
kwargs["conv_channels"] = 16
kwargs["base_channels"] = 16
kwargs["multilevel"] = True
kwargs["endpoint_channels"] = 64
super().__init__("basic", [3, 4, 6, 3], in_channels, **kwargs)