Source code for hyperion.torch.layer_blocks.spine_blocks

"""
 Copyright 2020 Magdalena Rybicka
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import torch.nn as nn
from torch.nn import Conv2d, BatchNorm2d, Dropout2d
import torch.nn.functional as nnf

from ..layers.subpixel_convs import SubPixelConv2d
from ..layers import ActivationFactory as AF

import logging


class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Having the functional resize as an ``nn.Module`` lets it be composed with
    other layers (e.g. appended to an ``nn.ModuleList``).

    Args:
      scale_factor: multiplier for the spatial size, forwarded unchanged to
        ``interpolate``.
      mode: interpolation algorithm name forwarded to ``interpolate``
        (default ``"nearest"``).
    """

    def __init__(self, scale_factor, mode="nearest"):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor
        self.interp = nnf.interpolate

    def forward(self, x):
        """Resize ``x`` by ``self.scale_factor`` using ``self.mode``."""
        return self.interp(x, scale_factor=self.scale_factor, mode=self.mode)
[docs]def _conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, bias=False): """3x3 convolution with padding""" return nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=bias, dilation=dilation, )
[docs]def _conv1x1(in_channels, out_channels, stride=1, bias=False): """1x1 convolution""" return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=bias)
def _subpixel_conv1x1(in_channels, out_channels, stride=1, bias=False):
    """Build a point-wise (1x1) sub-pixel convolution.

    Thin factory around the project's ``SubPixelConv2d``; ``stride`` and
    ``bias`` are forwarded unchanged. NOTE(review): how ``SubPixelConv2d``
    interprets ``stride`` (presumably as the upsampling factor) is defined in
    ``layers.subpixel_convs`` — confirm there.
    """
    layer = SubPixelConv2d(
        in_channels, out_channels, kernel_size=1, stride=stride, bias=bias
    )
    return layer
def _make_downsample(in_channels, out_channels, stride, norm_layer, norm_before):
    """Strided 3x3 convolution used to reduce spatial resolution.

    With ``norm_before`` the conv is bias-free and is followed by
    ``norm_layer(out_channels)`` in an ``nn.Sequential``; otherwise the bare
    conv (with bias) is returned.
    """
    if not norm_before:
        return _conv3x3(in_channels, out_channels, stride, bias=True)
    conv = _conv3x3(in_channels, out_channels, stride, bias=False)
    return nn.Sequential(conv, norm_layer(out_channels))


def _make_upsample(in_channels, out_channels, stride, norm_layer, norm_before):
    """Point-wise sub-pixel convolution used to increase spatial resolution.

    With ``norm_before`` the conv is bias-free and is followed by
    ``norm_layer(out_channels)`` in an ``nn.Sequential``; otherwise the bare
    conv (with bias) is returned.
    """
    if not norm_before:
        return _subpixel_conv1x1(in_channels, out_channels, stride, bias=True)
    conv = _subpixel_conv1x1(in_channels, out_channels, stride, bias=False)
    return nn.Sequential(conv, norm_layer(out_channels))


def _make_resample(
    channels, scale, norm_layer, norm_before, activation, upsampling_type="nearest"
):
    """Build the layers that rescale a ``channels``-wide feature map by ``scale``.

    Returns an ``nn.ModuleList`` meant to be applied in order:

    * ``scale > 1``: with ``upsampling_type == "subpixel"`` a sub-pixel
      upsampling conv (stride ``scale``) plus an activation; with
      ``"bilinear"`` a bilinear ``Interpolate``; any other value falls back to
      nearest-neighbour ``Interpolate``.
    * ``scale < 1``: a stride-2 3x3 conv plus an activation; if additionally
      ``scale < 0.5`` a max-pool with stride ``int(0.5 / scale)`` (kernel 3 for
      ``scale >= 0.25``, else 5) supplies the remaining reduction.
    * ``scale == 1``: an empty list (identity).

    The channel count is preserved by every layer. NOTE(review): the max-pool
    stride assumes ``scale`` is a power of two — confirm against callers.
    """
    layers = nn.ModuleList([])
    if scale > 1:
        if upsampling_type == "subpixel":
            layers.append(
                _make_upsample(channels, channels, scale, norm_layer, norm_before)
            )
            layers.append(AF.create(activation))
        else:
            interp_mode = "bilinear" if upsampling_type == "bilinear" else "nearest"
            layers.append(Interpolate(scale_factor=scale, mode=interp_mode))
    elif scale < 1:
        layers.append(_make_downsample(channels, channels, 2, norm_layer, norm_before))
        layers.append(AF.create(activation))
        if scale < 0.5:
            pool_kernel = 5 if scale < 0.25 else 3
            layers.append(
                nn.MaxPool2d(
                    kernel_size=pool_kernel,
                    stride=int(0.5 / scale),
                    padding=pool_kernel // 2,
                )
            )
    return layers
class SpineConv(nn.Module):
    """Point-wise (1x1) convolution, optional normalization, then an activation."""

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        dropout_rate=0,
        groups=1,
        dilation=1,
        activation={"name": "relu", "inplace": True},
        norm_layer=None,
        norm_before=True,
    ):
        """Point-wise conv + norm + activation block.

        Args:
          in_channels: number of input channels.
          channels: number of output channels of the 1x1 conv.
          stride: stride of the 1x1 conv.
          dropout_rate, groups, dilation: accepted but not used by this block.
          activation: activation spec passed to ``AF.create``.
          norm_layer: callable taking a channel count and returning a norm
            module; defaults to ``nn.BatchNorm2d``.
          norm_before: if True the conv has no bias and ``bn1`` is applied
            between the conv and the activation. If False the conv has a bias
            and ``bn1`` is still constructed (so it appears in the state dict)
            but ``forward`` never applies it.
        """
        super().__init__()
        self.channels = channels
        self.norm_before = norm_before
        norm_factory = nn.BatchNorm2d if norm_layer is None else norm_layer
        # The norm makes a conv bias redundant, so only keep the bias without it.
        self.conv1 = _conv1x1(in_channels, channels, stride, bias=not norm_before)
        self.bn1 = norm_factory(channels)
        self.act1 = AF.create(activation)

    def forward(self, x):
        """Apply conv1, then bn1 (only when ``norm_before``), then act1."""
        out = self.conv1(x)
        out = self.bn1(out) if self.norm_before else out
        return self.act1(out)
class BlockSpec(object):
    """A container class that specifies the block configuration for SpineNet.

    Attributes:
      level: level of the block. NOTE(review): presumably the feature-map
        resolution level — confirm against the SpineNet builder.
      block_fn: block type. ``SpineResample`` reads ``block_fn.expansion``
        from it.
      input_offsets: presumably indices of the earlier blocks that feed this
        one — confirm against the SpineNet builder.
      is_output: whether this block is an output of the network.
    """

    def __init__(self, level, block_fn, input_offsets, is_output):
        self.level = level
        self.block_fn = block_fn
        self.input_offsets = input_offsets
        self.is_output = is_output

    def __repr__(self):
        return "{}(level={!r}, block_fn={!r}, input_offsets={!r}, is_output={!r})".format(
            type(self).__name__,
            self.level,
            self.block_fn,
            self.input_offsets,
            self.is_output,
        )

    @staticmethod
    def build_block_specs(block_specs=None):
        """Builds the list of BlockSpec objects for SpineNet.

        Args:
          block_specs: iterable of ``(level, block_fn, input_offsets, is_output)``
            tuples, one per block.

        Returns:
          list of ``BlockSpec``, in the same order as ``block_specs``.

        Raises:
          TypeError: if ``block_specs`` is None. The None default has no
            fallback spec table behind it, so fail with a clear message instead
            of "'NoneType' object is not iterable".
        """
        if block_specs is None:
            raise TypeError(
                "build_block_specs() needs an iterable of "
                "(level, block_fn, input_offsets, is_output) tuples, got None"
            )
        return [BlockSpec(*b) for b in block_specs]
class SpineEndpoints(nn.Module):
    """Adapt one SpineNet output feature map for the rest of the network.

    Optionally projects the feature to ``channels`` with a 1x1 conv (+ norm
    + activation), then rescales it by ``2 ** (level - target_level)``
    (> 1 upsamples, < 1 downsamples, 1 leaves it unchanged).
    """

    def __init__(
        self,
        in_channels,
        channels,
        level,
        target_level,
        upsampling_type="nearest",
        stride=1,
        activation={"name": "relu", "inplace": True},
        norm_layer=None,
        norm_before=True,
        do_endpoint_conv=True,
    ):
        """Class that connects the outputs of the SpineNet to the rest of the network.

        Args:
          in_channels: number of channels of the incoming feature map.
          channels: requested number of output channels.
          level: level of the incoming feature.
          target_level: level the output should have; sets
            ``self.scale = 2 ** (level - target_level)``.
          upsampling_type: "subpixel", "bilinear", or anything else for
            nearest-neighbour upsampling (only used when scale > 1).
          stride: stride of the 1x1 projection conv.
          activation: activation spec passed to ``AF.create``.
          norm_layer: callable taking a channel count and returning a norm
            module; defaults to ``nn.BatchNorm2d``.
          norm_before: if True convs are bias-free and normalized before the
            activation; if False convs keep their bias and no norm is applied.
          do_endpoint_conv: if False, skip the 1x1 projection entirely. The
            projection is also skipped when ``in_channels == channels``. When
            skipped, ``self.channels`` is set to ``in_channels`` because the
            feature keeps its input width.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.in_channels = in_channels
        self.channels = channels
        self.norm_before = norm_before
        self.scale = 2 ** (level - target_level)
        self.do_endpoint_conv = do_endpoint_conv
        self.upsampling_type = upsampling_type
        bias = not norm_before
        if self.do_endpoint_conv and in_channels != channels:
            # in some cases this convolution is not necessary
            self.conv1 = _conv1x1(in_channels, channels, stride, bias=bias)
            self.bn1 = norm_layer(channels)
            self.act1 = AF.create(activation)
        else:
            # No projection: the feature keeps its input width.
            self.channels = in_channels

        # Build the resampler for the width the feature actually has at this
        # point (self.channels). Using the requested `channels` here would give
        # the subpixel/downsampling convs and norms the wrong width whenever
        # the projection is skipped and in_channels != channels.
        self.resample = _make_resample(
            self.channels,
            self.scale,
            norm_layer,
            norm_before,
            activation,
            upsampling_type=upsampling_type,
        )

    def forward(self, x):
        """Optionally project ``x`` to ``channels``, then apply the resampling layers."""
        if self.do_endpoint_conv and self.in_channels != self.channels:
            x = self.conv1(x)
            if self.norm_before:
                x = self.bn1(x)
            x = self.act1(x)

        for mod in self.resample:
            x = mod(x)

        return x
class SpineResample(nn.Module):
    """Resampling connection between two SpineNet blocks.

    Pipeline: 1x1 conv down to ``int(in_channels * alpha)`` channels
    (+ norm, + activation) -> rescale by ``2 ** scale`` -> 1x1 conv to
    ``out_channels`` (+ norm). No activation is applied at the end.
    """

    def __init__(
        self,
        spec,
        in_channels,
        out_channels,
        scale,
        alpha,
        upsampling_type="nearest",
        activation={"name": "relu", "inplace": True},
        norm_layer=None,
        norm_before=True,
    ):
        """Build a resampling connection between single SpineNet blocks.

        Args:
          spec: block spec of the producing block; ``spec.block_fn.expansion``
            multiplies ``in_channels`` to get the real input width.
          in_channels: base channel count of the producing block.
          out_channels: channel count delivered to the consuming block.
          scale: log2 of the resampling factor (stored as ``2 ** scale``).
          alpha: ratio for the intermediate width, ``int(in_channels * alpha)``,
            computed from the base (un-expanded) ``in_channels``.
          upsampling_type: forwarded to ``_make_resample``.
          activation: activation spec passed to ``AF.create``.
          norm_layer: callable taking a channel count and returning a norm
            module; defaults to ``BatchNorm2d``.
          norm_before: if True convs are bias-free and followed by their norm;
            if False convs keep a bias and the norms are not applied.
        """
        super().__init__()
        self.spec = spec
        hidden_channels = int(in_channels * alpha)
        expanded_in = in_channels * spec.block_fn.expansion
        self.scale = 2 ** scale
        self.norm_before = norm_before
        bias = not norm_before
        norm_factory = BatchNorm2d if norm_layer is None else norm_layer

        self.conv1 = _conv1x1(expanded_in, hidden_channels, bias=bias)
        self.bn1 = norm_factory(hidden_channels)
        self.act1 = AF.create(activation)
        self.resample = _make_resample(
            hidden_channels,
            self.scale,
            norm_factory,
            norm_before,
            activation,
            upsampling_type=upsampling_type,
        )
        self.conv2 = _conv1x1(hidden_channels, out_channels, bias=bias)
        self.bn2 = norm_factory(out_channels)

    def forward(self, x):
        """Project, rescale and re-project ``x``; the result is not activated."""
        h = self.conv1(x)
        if self.norm_before:
            h = self.bn1(h)
        h = self.act1(h)

        for layer in self.resample:
            h = layer(h)

        h = self.conv2(h)
        return self.bn2(h) if self.norm_before else h