Source code for hyperion.torch.narchs.spinenet

"""
 Copyright 2020 Magdalena Rybicka
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import Conv1d, Linear, BatchNorm1d

from ..layers import ActivationFactory as AF
from ..layers import NormLayer2dFactory as NLF
from ..layer_blocks import ResNetInputBlock, ResNetBasicBlock, ResNetBNBlock
from ..layer_blocks import Res2NetBNBlock, Res2NetBasicBlock
from ..layer_blocks import BlockSpec, SpineResample, SpineEndpoints, SpineConv
from .net_arch import NetArch

SPINENET_BLOCK_SPECS = [
    # level, block type, tuple of inputs, is output
    (2, ResNetBNBlock, (None, None), False),
    (2, ResNetBNBlock, (None, None), False),
    (2, ResNetBNBlock, (0, 1), False),
    (4, ResNetBasicBlock, (0, 1), False),
    (3, ResNetBNBlock, (2, 3), False),
    (4, ResNetBNBlock, (2, 4), False),
    (6, ResNetBasicBlock, (3, 5), False),
    (4, ResNetBNBlock, (3, 5), False),
    (5, ResNetBasicBlock, (6, 7), False),
    (7, ResNetBasicBlock, (6, 8), False),
    (5, ResNetBNBlock, (8, 9), False),
    (5, ResNetBNBlock, (8, 10), False),
    (4, ResNetBNBlock, (5, 10), True),
    (3, ResNetBNBlock, (4, 10), True),
    (5, ResNetBNBlock, (7, 12), True),
    (7, ResNetBNBlock, (5, 14), True),
    (6, ResNetBNBlock, (12, 14), True),
]

R0_SP53_BLOCK_SPECS = [
    # level, block type, tuple of inputs, is output
    (2, ResNetBNBlock, (None, None), False),  # 0
    (2, ResNetBNBlock, (None, None), False),  # 1
    (2, ResNetBNBlock, (0, 1), False),  # 2
    (3, ResNetBNBlock, (0, 1), False),  # 3
    (3, ResNetBNBlock, (2, 3), False),  # 4
    (4, ResNetBNBlock, (2, 4), False),  # 5
    (4, ResNetBNBlock, (3, 5), False),  # 6
    (3, ResNetBNBlock, (5, 6), False),  # 7
    (5, ResNetBNBlock, (4, 7), False),  # 8
    (4, ResNetBNBlock, (4, 8), False),  # 9
    (4, ResNetBNBlock, (8, 9), False),  # 10
    (4, ResNetBNBlock, (8, 10), False),  # 11
    (3, ResNetBNBlock, (4, 10), True),  # 12
    (4, ResNetBNBlock, (6, 7), True),  # 13 it has 3 inputs
    (5, ResNetBNBlock, (8, 13), True),  # 14
    (7, ResNetBNBlock, (6, 9), True),  # 15
    (6, ResNetBNBlock, (7, 9), True),  # 16
]

SPINENET_BLOCK_SPECS_5 = [
    # level, block type, tuple of inputs, is output
    (2, ResNetBNBlock, (None, None), False),  # 0
    (2, ResNetBNBlock, (None, None), False),  # 1
    (2, ResNetBNBlock, (0, 1), False),  # 2
    (4, ResNetBasicBlock, (0, 1), False),  # 3
    (3, ResNetBNBlock, (2, 3), False),  # 4
    (4, ResNetBNBlock, (2, 4), False),  # 5
    (6, ResNetBasicBlock, (3, 5), False),  # 6
    (4, ResNetBNBlock, (3, 5), False),  # 7
    (5, ResNetBasicBlock, (6, 7), False),  # 8
    (7, ResNetBasicBlock, (6, 8), False),  # 9
    (5, ResNetBNBlock, (8, 9), False),  # 10
    (5, ResNetBNBlock, (8, 10), False),  # 11
    (4, ResNetBNBlock, (5, 10), True),  # 12
    # (3, ResNetBNBlock, (4, 10), True),      # 13
    (5, ResNetBNBlock, (7, 12), True),  # 14
    # (7, ResNetBNBlock, (5, 14), True),      # 15
    # (6, ResNetBNBlock, (12, 14), True),     # 16
]

FILTER_SIZE_MAP = {
    # level: channel multiplier
    1: 0.5,
    2: 1,
    3: 2,
    4: 4,
    5: 4,
    6: 4,
    7: 4,
}


[docs]class SpineNet(NetArch):
[docs] def __init__( self, in_channels, block_specs=None, output_levels=[3, 4, 5, 6, 7], endpoints_num_filters=256, resample_alpha=0.5, feature_output_level=None, block_repeats=1, filter_size_scale=1.0, conv_channels=64, base_channels=64, out_units=0, concat=False, do_endpoint_conv=True, concat_ax=3, upsampling_type="nearest", hid_act={"name": "relu6", "inplace": True}, out_act=None, in_kernel_size=7, in_stride=2, zero_init_residual=False, groups=1, dropout_rate=0, norm_layer=None, norm_before=True, do_maxpool=True, in_norm=True, in_feats=None, se_r=16, time_se=False, has_se=False, is_res2net=False, res2net_scale=4, res2net_width_factor=1, ): """ Base class for the SpineNet structure. Based on the paper SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization Xianzhi Du, Tsung-Yi Lin, Pengchong Jin, Golnaz Ghiasi, Mingxing Tan, Yin Cui, Quoc V. Le, Xiaodan Song https://arxiv.org/abs/1912.05027 :param in_channels: nbr of channels of the input :param block_specs: specification of the building blocks: their type, input connections and information if block is an output :param output_levels: the output levels of the blocks that are taken as an output of the SpineNet :param endpoints_num_filters: the base number of channels out the SpineNet output :param resample_alpha: parameter for resampling connections :param concat: bool that decides wheter the outputs are concatenated or averaged :param do_endpoint_conv: bool that decides whether to do the projection of the output blocks to the common number of channels (the value of the number is the endpoints_num_filters) :param concat_ax: the axis along which perform the concatenation (if the concatenation is chosen) :param feature_output_level: the level that the output feature map sizes are resampled to (by default the target size is the biggest feature map) :param filter_size_scale: SpineNet parameter, that additionally rescales the number of channels of the SpineNet blocks(needed for bigger structures like SpineNet96 and higher or for SpineNet49S) """ super().__init__() self.in_channels = in_channels self.conv_channels = conv_channels self.base_channels = base_channels self.out_units = out_units self.endpoints_num_filters = endpoints_num_filters self.resample_alpha = resample_alpha self.block_repeats = block_repeats self.filter_size_scale = filter_size_scale self.concat = concat self.concat_ax = concat_ax self.do_endpoint_conv = do_endpoint_conv self.feature_output_level = ( min(output_levels) if feature_output_level is None else feature_output_level ) self.res2net_scale = res2net_scale self.res2net_width_factor = res2net_width_factor self.is_res2net = is_res2net self.se_r = se_r self.time_se = time_se self.has_se = has_se self._block_specs = ( BlockSpec.build_block_specs(SPINENET_BLOCK_SPECS) if block_specs is None else BlockSpec.build_block_specs(block_specs) ) self.output_levels = output_levels self.upsampling_type = upsampling_type self.dilation = 1 self.hid_act = hid_act self.in_kernel_size = in_kernel_size self.in_stride = in_stride self.groups = groups self.norm_before = norm_before self.do_maxpool = do_maxpool self.dropout_rate = dropout_rate self.in_norm = in_norm self.in_feats = in_feats self.norm_layer = norm_layer norm_groups = None if norm_layer == "group-norm": norm_groups = min(base_channels // 2, 32) norm_groups = max(norm_groups, groups) self._norm_layer = NLF.create(norm_layer, norm_groups) if in_norm: self.in_bn = norm_layer(in_channels) self.in_block = ResNetInputBlock( in_channels, conv_channels, kernel_size=in_kernel_size, stride=in_stride, activation=hid_act, norm_layer=self._norm_layer, norm_before=norm_before, do_maxpool=do_maxpool, ) if self.is_res2net: if self._block_specs[0].block_fn == ResNetBNBlock: _in_block = Res2NetBNBlock elif self._block_specs[0].block_fn == ResNetBasicBlock: _in_block = Res2NetBasicBlock else: _in_block = self._block_specs[0].block_fn self.stem0 = self._make_layer( _in_block, 2, self.block_repeats, in_channels=conv_channels ) self.stem1 = self._make_layer(_in_block, 2, self.block_repeats) self.stem_nbr = 2 # the number of the stem layers self.blocks = self._make_permuted_blocks(self._block_specs[self.stem_nbr :]) self.connections = self._make_permuted_connections( self._block_specs[self.stem_nbr :] ) self.endpoints = self._make_endpoints() self._context = self._compute_max_context(self.in_block.context) self._downsample_factor = self.in_block.downsample_factor * 2 ** ( self.feature_output_level - 2 ) self.with_output = False self.out_act = None if out_units > 0: self.with_output = True self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) cur_channels = self._compute_channel_size() self.output = nn.Linear(cur_channels, out_units) self.out_act = AF.create(out_act) for m in self.modules(): if isinstance(m, nn.Conv2d): act_name = "relu" if isinstance(hid_act, str): act_name = hid_act if isinstance(hid_act, dict): act_name = hid_act["name"] if act_name == "swish": act_name = "relu" try: nn.init.kaiming_normal_( m.weight, mode="fan_out", nonlinearity=act_name ) except: nn.init.kaiming_normal_( m.weight, mode="fan_out", nonlinearity="relu" ) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 self.zero_init_residual = zero_init_residual if zero_init_residual: for m in self.modules(): if isinstance(m, ResNetBNBlock): nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, ResNetBasicBlock): nn.init.constant_(m.bn2.weight, 0)
[docs] def _make_permuted_blocks(self, block_specs): """ Builds the blocks of the SpineNet structure. """ blocks = nn.ModuleList([]) for block in block_specs: if self.is_res2net: if block.block_fn == ResNetBNBlock: _block = Res2NetBNBlock elif block.block_fn == ResNetBasicBlock: _block = Res2NetBasicBlock else: _block = block.block_fn else: _block = block.block_fn layer_i = self._make_layer(_block, block.level, self.block_repeats) blocks.append(layer_i) return blocks
[docs] def _make_permuted_connections(self, block_specs): """ Builds the cross-scale connections between the blocks. """ connections = nn.ModuleList([]) for block in block_specs: expansion = block.block_fn.expansion out_channels = ( int( FILTER_SIZE_MAP[block.level] * self.filter_size_scale * self.base_channels ) * expansion ) connections_i = nn.ModuleList([]) for i in block.input_offsets: offset_block = self._block_specs[i] scale = offset_block.level - block.level in_channels = int( FILTER_SIZE_MAP[offset_block.level] * self.filter_size_scale * self.base_channels ) connections_i.append( SpineResample( offset_block, in_channels, out_channels, scale, self.resample_alpha, self.upsampling_type, activation=self.hid_act, norm_layer=self._norm_layer, norm_before=self.norm_before, ) ) connections_i.append(AF.create(self.hid_act)) connections.append(connections_i) return connections
[docs] def _make_endpoints(self): """ Builds the output endpoint blocks. In this part, the block outputs are forwarded through the 1x1 convs to the common number of channels (endpoints_num_filters) and feature maps are resized to the size of the feature_output_level. """ endpoints = nn.ModuleDict() for block_spec in self._block_specs: if block_spec.is_output and block_spec.level in self.output_levels: expansion = block_spec.block_fn.expansion in_channels = ( int( FILTER_SIZE_MAP[block_spec.level] * self.filter_size_scale * self.base_channels ) * expansion ) out_channels = ( self.endpoints_num_filters if self.do_endpoint_conv else in_channels ) endpoints[str(block_spec.level)] = SpineEndpoints( in_channels, out_channels, block_spec.level, self.feature_output_level, self.upsampling_type, activation=self.hid_act, norm_layer=self._norm_layer, norm_before=self.norm_before, do_endpoint_conv=self.do_endpoint_conv, ) return endpoints
def _make_layer( self, block, block_level, num_blocks, in_channels=None, stride=1, dilate=False ): previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 kwargs = {} if self.has_se: if self.time_se: num_feats = int(self.in_feats / self.in_block.downsample_factor) for i in range(block_level - 2): num_feats = ( int(num_feats // 2) if num_feats % 2 == 0 else int(num_feats // 2 + 1) ) kwargs = {"se_r": self.se_r, "time_se": True, "num_feats": num_feats} else: kwargs = {"se_r": self.se_r} if self.is_res2net and block != ResNetBasicBlock: kwargs["scale"] = self.res2net_scale kwargs["width_factor"] = self.res2net_width_factor channels = int( FILTER_SIZE_MAP[block_level] * self.base_channels * self.filter_size_scale ) if in_channels is None: in_channels = channels * block.expansion layers = [] layers.append( block( in_channels, channels, activation=self.hid_act, stride=stride, dropout_rate=self.dropout_rate, groups=self.groups, dilation=previous_dilation, norm_layer=self._norm_layer, norm_before=self.norm_before, **kwargs ) ) cur_in_channels = channels * block.expansion for _ in range(1, num_blocks): layers.append( block( cur_in_channels, channels, activation=self.hid_act, dropout_rate=self.dropout_rate, groups=self.groups, dilation=self.dilation, norm_layer=self._norm_layer, norm_before=self.norm_before, **kwargs ) ) return nn.Sequential(*layers)
[docs] def _compute_max_context(self, in_context): """ Computes maximum possible context in the structure. The method may need a deeper revision. :param in_context: context from the input residual block. """ block_context = { # we can define specific values as inside the network the dilation or stride is not applied ResNetBNBlock: 1, ResNetBasicBlock: 2, } base_downsample_factor = self.in_block.downsample_factor context0 = in_context # context of the first two blocks (stem part) context0 += ( base_downsample_factor * block_context[self._block_specs[0].block_fn] * self.block_repeats ) context1 = ( context0 + base_downsample_factor * block_context[self._block_specs[1].block_fn] * self.block_repeats ) contexts = [context0, context1] # context in the scale permuted part num_outgoing_connections = [0, 0] for idx, block in enumerate(self._block_specs[self.stem_nbr :]): input0 = block.input_offsets[0] input1 = block.input_offsets[1] target_level = block.level # we add context if in the resampling connection was downsampling operation (it includes 3x3 convolution) resample0 = ( self._block_specs[input0].level + 1 if self._block_specs[input0].level - target_level < 0 else 0 ) resample1 = ( self._block_specs[input1].level + 1 if self._block_specs[input1].level - target_level < 0 else 0 ) parent0_context = contexts[input0] + resample0 parent1_context = contexts[input1] + resample1 # as input context we choose the input with higher value target_context = max(parent0_context, parent1_context) num_outgoing_connections[input0] += 1 num_outgoing_connections[input1] += 1 # Connect intermediate blocks with outdegree 0 to the output block. # Some blocks have also this additional connection if block.is_output: for j, j_connections in enumerate(num_outgoing_connections): if ( j_connections == 0 and self._block_specs[j].level == target_level ): target_context = max(contexts[j], target_context) num_outgoing_connections[j] += 1 downsample_factor = base_downsample_factor * 2 ** (target_level - 2) target_context += ( block_context[block.block_fn] * self.block_repeats * downsample_factor ) contexts.append(target_context) num_outgoing_connections.append(0) # logging.info('block\'s contexts: {}'.format(contexts)) return max(contexts)
[docs] def _compute_out_size(self, in_size): """Computes output size given input size. Output size is not the same as input size because of downsampling steps. Args: in_size: input size of the H or W dimensions Returns: output_size """ out_size = int((in_size - 1) // self.in_stride + 1) if self.do_maxpool: out_size = int((out_size - 1) // 2 + 1) downsample_levels = self.feature_output_level - 2 for i in range(downsample_levels): out_size = ( int(out_size // 2) if out_size % 2 == 0 else int(out_size // 2 + 1) ) return out_size
[docs] def _compute_channel_size(self): """ Returns: If the 1x1 conv is not conducted in the endpoint blocks, the number of channels is equal to the sum of the nbr of channels of the output blocks. """ if not self.do_endpoint_conv: C = 0 for output_level in self.output_levels: C += self.base_channels * 4 * FILTER_SIZE_MAP[output_level] return C else: if self.concat and self.concat_ax == 1: C = len(self.output_levels) * self.endpoints_num_filters return C return self.endpoints_num_filters
# def in_context(self): # """ # Returns: # Tuple (past, future) context required to predict one frame. # """ # return (self._context, self._context)
[docs] def in_shape(self): """ Returns: Tuple describing input shape for the network """ return (None, self.in_channels, None, None)
[docs] def out_shape(self, in_shape=None): """Computes the output shape given the input shape # # Args: # in_shape: input shape # Returns: # Tuple describing output shape for the network #""" if self.with_output: return (None, self.out_units) if in_shape is None: return (None, self.endpoints_num_filters, None, None) assert len(in_shape) == 4 if in_shape[2] is None: H = None else: H = self._compute_out_size(in_shape[2]) # in case of concatenation along feature dimension if self.concat_ax == 2 and self.concat: H = H * len(self.output_levels) if in_shape[3] is None: W = None else: W = self._compute_out_size(in_shape[3]) C = self._compute_channel_size() return (in_shape[0], C, H, W)
def _match_shape(self, x, target_shape): x_dim = x.dim() ddim = x_dim - len(target_shape) for i in range(2, x_dim): surplus = x.size(i) - target_shape[i - ddim] assert surplus >= 0 if surplus > 0: x = torch.narrow(x, i, surplus // 2, target_shape[i - ddim]) return x.contiguous()
[docs] def _match_feat_shape(self, feat0, feat1): """ Match shape between feats of the input connections. """ surplus = feat1.size(3) - feat0.size(3) if surplus >= 0: feat1 = self._match_shape(feat1, list(feat0.size())[2:]) else: feat0 = self._match_shape(feat0, list(feat1.size())[2:]) return feat0, feat1
[docs] def forward(self, x, use_amp=False): if use_amp: with torch.cuda.amp.autocast(): return self._forward(x) return self._forward(x)
[docs] def _forward(self, x): """forward function Args: x: input tensor of size=(batch, Cin, Hin, Win) for image or size=(batch, C, freq, time) for audio Returns: Tensor with output logits of size=(batch, out_units) if out_units>0, otherwise, it returns tensor of represeantions of size=(batch, Cout, Hout, Wout) """ if self.in_norm: x = self.in_bn(x) x = self.in_block(x) feat0 = self.stem0(x) feat1 = self.stem1(feat0) feats = [feat0, feat1] output_feats = {} num_outgoing_connections = [0, 0] for idx, block in enumerate(self._block_specs[self.stem_nbr :]): input0 = block.input_offsets[0] input1 = block.input_offsets[1] parent0_feat = self.connections[idx][0](feats[input0]) parent1_feat = self.connections[idx][1](feats[input1]) parent0_feat, parent1_feat = self._match_feat_shape( parent0_feat, parent1_feat ) target_feat = parent0_feat + parent1_feat num_outgoing_connections[input0] += 1 num_outgoing_connections[input1] += 1 # Connect intermediate blocks with outdegree 0 to the output block. if block.is_output: for j, (j_feat, j_connections) in enumerate( zip(feats, num_outgoing_connections) ): if j_connections == 0 and j_feat.shape == target_feat.shape: target_feat += j_feat num_outgoing_connections[j] += 1 target_feat = self.connections[idx][2]( target_feat ) # pass input through the activation function x = self.blocks[idx](target_feat) feats.append(x) num_outgoing_connections.append(0) if block.is_output and block.level in self.output_levels: if str(block.level) in output_feats: raise ValueError( "Duplicate feats found for output level {}.".format(block.level) ) output_feats[str(block.level)] = x output_endpoints = [] output_shape = list( output_feats[str(self.feature_output_level)].size() ) # get the target output size for endpoint in self.endpoints: if self.endpoints[endpoint] is not None: endpoint_i = self.endpoints[endpoint](output_feats[endpoint]) else: endpoint_i = output_feats[endpoint] endpoint_i = self._match_shape(endpoint_i, output_shape) output_endpoints.append(endpoint_i) if self.concat: x = torch.cat(output_endpoints, self.concat_ax) else: x = torch.mean(torch.stack(output_endpoints), 0) if self.with_output: x = self.avgpool(x) x = torch.flatten(x, 1) x = self.output(x) if self.out_act is not None: x = self.out_act(x) return x
[docs] def get_config(self): """Gets network config Returns: dictionary with config params """ out_act = AF.get_config(self.out_act) hid_act = self.hid_act config = { "in_channels": self.in_channels, "in_kernel_size": self.in_kernel_size, "in_stride": self.in_stride, "conv_channels": self.conv_channels, "base_channels": self.base_channels, "endpoints_num_filters": self.endpoints_num_filters, "resample_alpha": self.resample_alpha, "block_repeats": self.block_repeats, "filter_size_scale": self.filter_size_scale, "output_levels": self.output_levels, "feature_output_level": self.feature_output_level, "out_units": self.out_units, "concat": self.concat, "concat_ax": self.concat_ax, "do_endpoint_conv": self.do_endpoint_conv, "upsampling_type": self.upsampling_type, "zero_init_residual": self.zero_init_residual, "groups": self.groups, "dropout_rate": self.dropout_rate, "norm_layer": self.norm_layer, "norm_before": self.norm_before, "in_norm": self.in_norm, "do_maxpool": self.do_maxpool, "out_act": out_act, "hid_act": hid_act, "se_r": self.se_r, "in_feats": self.in_feats, "res2net_scale": self.res2net_scale, "res2net_width_factor": self.res2net_width_factor, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items()))
# SpineNet structures from the original paper
[docs]class SpineNet49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 256 kwargs["filter_size_scale"] = 1.0 kwargs["resample_alpha"] = 0.5 kwargs["block_repeats"] = 1 super(SpineNet49, self).__init__(in_channels, **kwargs)
[docs]class SpineNet49S(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 128 kwargs["filter_size_scale"] = 0.66 kwargs["resample_alpha"] = 0.5 kwargs["block_repeats"] = 1 super(SpineNet49S, self).__init__(in_channels, **kwargs)
[docs]class SpineNet96(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 256 kwargs["filter_size_scale"] = 1.0 kwargs["resample_alpha"] = 0.5 kwargs["block_repeats"] = 2 super(SpineNet96, self).__init__(in_channels, **kwargs)
[docs]class SpineNet143(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 256 kwargs["filter_size_scale"] = 1.0 kwargs["resample_alpha"] = 1.0 kwargs["block_repeats"] = 3 super(SpineNet143, self).__init__(in_channels, **kwargs)
[docs]class SpineNet190(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 512 kwargs["filter_size_scale"] = 1.3 kwargs["resample_alpha"] = 1.0 kwargs["block_repeats"] = 4 super(SpineNet190, self).__init__(in_channels, **kwargs)
# SpineNet modifications # Light SpineNets
[docs]class LSpineNet49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 super(LSpineNet49, self).__init__(in_channels, **kwargs)
[docs]class LSpineNet49_subpixel(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 kwargs["upsampling_type"] = "subpixel" super(LSpineNet49_subpixel, self).__init__(in_channels, **kwargs)
[docs]class LSpineNet49_bilinear(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 kwargs["upsampling_type"] = "bilinear" super(LSpineNet49_bilinear, self).__init__(in_channels, **kwargs)
[docs]class LSpineNet49_5(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 kwargs["output_levels"] = [5] kwargs["do_endpoint_conv"] = False kwargs["block_specs"] = SPINENET_BLOCK_SPECS_5 super(LSpineNet49_5, self).__init__(in_channels, **kwargs)
[docs]class LSpine2Net49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 kwargs["is_res2net"] = True super(LSpine2Net49, self).__init__(in_channels, **kwargs)
# Spine2Nets ans(Time-)Squeeze-and-Excitation
[docs]class SELSpine2Net49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 kwargs["is_res2net"] = True kwargs["has_se"] = True super(SELSpine2Net49, self).__init__(in_channels, **kwargs)
[docs]class TSELSpine2Net49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 kwargs["is_res2net"] = True kwargs["has_se"] = True kwargs["time_se"] = True super(TSELSpine2Net49, self).__init__(in_channels, **kwargs)
[docs]class Spine2Net49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["is_res2net"] = True super(Spine2Net49, self).__init__(in_channels, **kwargs)
[docs]class SESpine2Net49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["is_res2net"] = True kwargs["has_se"] = True super(SESpine2Net49, self).__init__(in_channels, **kwargs)
[docs]class TSESpine2Net49(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["is_res2net"] = True kwargs["has_se"] = True kwargs["time_se"] = True super(TSESpine2Net49, self).__init__(in_channels, **kwargs)
[docs]class Spine2Net49S(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 128 kwargs["filter_size_scale"] = 0.66 kwargs["is_res2net"] = True super(Spine2Net49S, self).__init__(in_channels, **kwargs)
[docs]class SESpine2Net49S(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 128 kwargs["filter_size_scale"] = 0.66 kwargs["is_res2net"] = True kwargs["has_se"] = True super(SESpine2Net49S, self).__init__(in_channels, **kwargs)
[docs]class TSESpine2Net49S(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 128 kwargs["filter_size_scale"] = 0.66 kwargs["is_res2net"] = True kwargs["has_se"] = True kwargs["time_se"] = True super(TSESpine2Net49S, self).__init__(in_channels, **kwargs)
# R0-SP53 (structure from the paper)
[docs]class LR0_SP53(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["endpoints_num_filters"] = 64 kwargs["conv_channels"] = 16 kwargs["base_channels"] = 16 kwargs["block_specs"] = R0_SP53_BLOCK_SPECS super(LR0_SP53, self).__init__(in_channels, **kwargs)
[docs]class R0_SP53(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["block_specs"] = R0_SP53_BLOCK_SPECS super(R0_SP53, self).__init__(in_channels, **kwargs)
# concatenation
[docs]class SpineNet49_concat_time(SpineNet):
[docs] def __init__(self, in_channels, **kwargs): kwargs["concat"] = True super(SpineNet49_concat_time, self).__init__(in_channels, **kwargs)