# Source code for hyperion.torch.layer_blocks.mbconv_blocks

"""
 Copyright 2020 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
#

import torch
import torch.nn as nn

# from torch.nn import Conv2d, BatchNorm2d

from ..layers import ActivationFactory as AF
from ..layers import DropConnect2d
from .se_blocks import SEBlock2D, TSEBlock2D


def _conv1x1(in_channels, out_channels, stride=1, bias=False):
    """Create a point-wise (1x1) 2D convolution.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels.
      stride: stride of the convolution.
      bias: if True, the convolution has a learnable bias.

    Returns:
      nn.Conv2d with kernel_size=1 and no padding.
    """
    pointwise = nn.Conv2d(
        in_channels, out_channels, 1, stride=stride, bias=bias
    )
    return pointwise
def _dwconvkxk(channels, kernel_size=3, stride=1, bias=False):
    """Create a depth-wise k x k 2D convolution with zero padding.

    The layer uses one group per channel (groups=channels), so every
    channel is filtered independently. The padding is (kernel_size - 1) // 2
    on each side.

    Args:
      channels: number of input (and output) channels.
      kernel_size: size of the square kernel.
      stride: stride of the convolution.
      bias: if True, the convolution has a learnable bias.

    Returns:
      nn.Conv2d configured as a depth-wise convolution.
    """
    half_kernel = (kernel_size - 1) // 2
    depthwise = nn.Conv2d(
        channels,
        channels,
        kernel_size,
        stride=stride,
        padding=half_kernel,
        groups=channels,
        bias=bias,
        padding_mode="zeros",
    )
    return depthwise
def _make_downsample(in_channels, out_channels, stride, norm_layer):
    """Build the projection used on the residual branch.

    The projection is a bias-free 1x1 convolution followed by a
    normalization layer created with momentum=0.01 and eps=1e-3.

    Args:
      in_channels: number of channels entering the residual branch.
      out_channels: number of channels the residual must have.
      stride: stride applied by the 1x1 convolution.
      norm_layer: callable building the normalization layer; it is called
        as norm_layer(out_channels, momentum=0.01, eps=1e-3).

    Returns:
      nn.Sequential(conv1x1, norm).
    """
    projection = _conv1x1(in_channels, out_channels, stride, bias=False)
    normalization = norm_layer(out_channels, momentum=0.01, eps=1e-3)
    return nn.Sequential(projection, normalization)
class MBConvBlock(nn.Module):
    """Residual block made of an optional 1x1 expansion, a depth-wise
    convolution, an optional squeeze-excitation layer and a 1x1 projection.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels.
      expansion: multiplier applied to in_channels to get the inner
        channels; the expansion convolution is only created when > 1.
      kernel_size: kernel size of the depth-wise convolution.
      stride: stride of the depth-wise convolution (and of the residual
        projection when one is needed).
      activation: activation specification passed to ActivationFactory.create
        and to the squeeze-excitation block.
      drop_connect_rate: when > 0, a DropConnect2d layer with this rate is
        applied to the projected output before adding the residual.
      norm_layer: callable building normalization layers, called as
        norm_layer(channels, momentum=0.01, eps=1e-3); nn.BatchNorm2d if None.
      se_r: argument passed to the squeeze-excitation block; the block is
        only created when se_r is not None and se_r > 1.
      time_se: if True, TSEBlock2D is used instead of SEBlock2D.
      num_feats: used only when time_se is True and squeeze-excitation is
        enabled; (num_feats + stride - 1) // stride is passed to TSEBlock2D.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        expansion=6,
        kernel_size=3,
        stride=1,
        activation="swish",
        drop_connect_rate=0,
        norm_layer=None,
        se_r=None,
        time_se=False,
        num_feats=None,
    ):
        super().__init__()
        hidden_channels = in_channels * expansion
        use_se = se_r is not None and se_r > 1

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion = expansion
        self.inner_channels = hidden_channels
        self.kernel_size = kernel_size
        self.act = AF.create(activation)
        self.stride = stride
        self.se_r = se_r
        self.has_se = use_se
        self.time_se = time_se
        self.num_feats = num_feats

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        def make_norm(num_channels):
            # every normalization layer of the block shares these settings
            return norm_layer(num_channels, momentum=0.01, eps=1e-3)

        # expansion phase (skipped when there is nothing to expand)
        if expansion > 1:
            self.conv_exp = _conv1x1(in_channels, hidden_channels)
            self.bn_exp = make_norm(hidden_channels)

        # depthwise conv phase
        self.conv_dw = _dwconvkxk(hidden_channels, kernel_size, stride)
        self.bn_dw = make_norm(hidden_channels)

        # squeeze-excitation block
        if use_se and time_se:
            feats_after_stride = (num_feats + stride - 1) // stride
            self.se_layer = TSEBlock2D(
                hidden_channels, feats_after_stride, se_r, activation
            )
        elif use_se:
            self.se_layer = SEBlock2D(hidden_channels, se_r, activation)

        # projection phase
        self.conv_proj = _conv1x1(hidden_channels, out_channels)
        self.bn_proj = make_norm(out_channels)

        self.drop_connect_rate = drop_connect_rate
        self.drop_connect = (
            DropConnect2d(drop_connect_rate) if drop_connect_rate > 0 else None
        )

        # when input and output dimensions are different, we adapt the
        # dimensions using conv1x1; this is different from the official
        # implementation, where the residual connection is removed
        needs_projection = stride != 1 or in_channels != out_channels
        self.downsample = (
            _make_downsample(in_channels, out_channels, stride, norm_layer)
            if needs_projection
            else None
        )

        self.context = stride * (kernel_size - 1) // 2
        self.downsample_factor = stride

    def forward(self, x):
        """Apply the block to x and return the result.

        The input goes through expansion (if any), depth-wise convolution,
        squeeze-excitation (if any), projection and drop-connect (if any);
        the residual, projected by self.downsample when that layer exists,
        is then added in place to the output.
        """
        shortcut = x

        if self.expansion > 1:
            x = self.conv_exp(x)
            x = self.bn_exp(x)
            x = self.act(x)

        x = self.conv_dw(x)
        x = self.bn_dw(x)
        x = self.act(x)

        if self.has_se:
            x = self.se_layer(x)

        # no activation after the projection
        x = self.conv_proj(x)
        x = self.bn_proj(x)

        if self.drop_connect_rate > 0:
            x = self.drop_connect(x)

        if self.downsample is not None:
            shortcut = self.downsample(shortcut)

        x += shortcut
        return x
class MBConvInOutBlock(nn.Module):
    """Convolution + normalization + activation block.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels.
      kernel_size: kernel size of the convolution.
      stride: stride of the convolution.
      activation: activation specification passed to ActivationFactory.create.
      norm_layer: callable building the normalization layer, called as
        norm_layer(out_channels, momentum=0.01, eps=1e-3); nn.BatchNorm2d
        if None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=2,
        activation="swish",
        norm_layer=None,
    ):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        self.in_channels = in_channels
        self.out_channels = out_channels

        half_kernel = (kernel_size - 1) / 2
        pad = int(half_kernel)

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            bias=False,
            padding_mode="zeros",
        )
        self.bn = make_norm(out_channels, momentum=0.01, eps=1e-3)
        self.act = AF.create(activation)

        self.context = pad
        self.downsample_factor = stride

    def forward(self, x):
        """Return act(bn(conv(x)))."""
        x = self.conv(x)
        x = self.bn(x)
        return self.act(x)