Source code for hyperion.torch.layer_blocks.se_blocks

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import torch
import torch.nn as nn
from torch.nn import Conv2d, Conv1d

from ..layers import ActivationFactory as AF


class SEBlock2D(nn.Module):
    """Squeeze-and-excitation block for 4D (batch, channel, dim2, dim3) inputs.

    From https://arxiv.org/abs/1709.01507

    The input is averaged over its last two dimensions, passed through a
    1x1-conv bottleneck (conv -> activation -> conv -> sigmoid) and the
    resulting per-channel gate is multiplied back onto the input.

    Args:
      num_channels: number of channels of the input (dimension 1).
      r: reduction ratio; the bottleneck has ``int(num_channels / r)``
        channels.
      activation: activation specification forwarded to ``AF.create``.
    """

    def __init__(
        self, num_channels, r=16, activation={"name": "relu", "inplace": True}
    ):
        super().__init__()
        # NOTE(review): the default ``activation`` dict is shared between
        # calls; this assumes AF.create does not mutate it -- confirm.
        bottleneck_channels = int(num_channels / r)
        self.conv1 = nn.Conv2d(
            num_channels, bottleneck_channels, kernel_size=1, bias=False
        )
        self.act = AF.create(activation)
        self.conv2 = nn.Conv2d(
            bottleneck_channels, num_channels, kernel_size=1, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale the channels of ``x`` with the squeeze-excitation gate.

        Args:
          x: 4D tensor whose dimension 1 has ``num_channels`` entries.

        Returns:
          Tensor with the same shape as ``x``.
        """
        # Squeeze: average over dimensions 2 and 3, keeping them as size 1
        # so the gate broadcasts against x.
        z = torch.mean(x, dim=(2, 3), keepdim=True)
        # Excitation: bottleneck followed by a sigmoid gate in (0, 1).
        scale = self.sigmoid(self.conv2(self.act(self.conv1(z))))
        return scale * x
class TSEBlock2D(nn.Module):
    """Squeeze-and-excitation block that pools only over the last dimension.

    From https://arxiv.org/abs/1709.01507
    Modified to do pooling only in time dimension

    The input is averaged over its last dimension only, so a separate gate
    is computed for every (channel, feature) pair. Dimensions 1 and 2 are
    flattened into ``num_channels * num_feats`` channels for the 1x1-conv
    bottleneck.

    Args:
      num_channels: number of channels of the input (dimension 1).
      num_feats: size of dimension 2 of the input.
      r: reduction ratio; the bottleneck has
        ``int(num_channels * num_feats / r)`` channels.
      activation: activation specification forwarded to ``AF.create``.
    """

    def __init__(
        self,
        num_channels,
        num_feats,
        r=16,
        activation={"name": "relu", "inplace": True},
    ):
        super().__init__()
        # NOTE(review): the default ``activation`` dict is shared between
        # calls; this assumes AF.create does not mutate it -- confirm.
        self.num_channels_1d = num_channels * num_feats
        bottleneck_channels = int(self.num_channels_1d / r)
        self.conv1 = nn.Conv2d(
            self.num_channels_1d,
            bottleneck_channels,
            kernel_size=1,
            bias=False,
        )
        self.act = AF.create(activation)
        self.conv2 = nn.Conv2d(
            bottleneck_channels,
            self.num_channels_1d,
            kernel_size=1,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale ``x`` with a gate computed per (channel, feature) pair.

        Args:
          x: 4D tensor; the product of its dimensions 1 and 2 must equal
            ``self.num_channels_1d``.

        Returns:
          Tensor with the same shape as ``x``.
        """
        num_channels = x.shape[1]
        num_feats = x.shape[2]
        # Squeeze over the last dimension only.
        z = torch.mean(x, dim=-1, keepdim=True)
        # Fold (channel, feature) into a single channel axis for the convs.
        z = z.view(-1, self.num_channels_1d, 1, 1)
        scale = self.sigmoid(self.conv2(self.act(self.conv1(z))))
        # Unfold back so the gate broadcasts along the last dimension of x.
        scale = scale.view(-1, num_channels, num_feats, 1)
        return scale * x
class SEBlock1d(nn.Module):
    """1d Squeeze Excitation version of https://arxiv.org/abs/1709.01507

    The input is averaged over dimension 2, passed through a kernel-size-1
    ``Conv1d`` bottleneck (conv -> activation -> conv -> sigmoid) and the
    resulting per-channel gate is multiplied back onto the input.

    Args:
      num_channels: number of channels of the input (dimension 1).
      r: reduction ratio; the bottleneck has ``int(num_channels / r)``
        channels.
      activation: activation specification forwarded to ``AF.create``.
    """

    def __init__(
        self, num_channels, r=16, activation={"name": "relu", "inplace": True}
    ):
        super().__init__()
        # NOTE(review): the default ``activation`` dict is shared between
        # calls; this assumes AF.create does not mutate it -- confirm.
        bottleneck_channels = int(num_channels / r)
        self.conv1 = nn.Conv1d(
            num_channels, bottleneck_channels, kernel_size=1, bias=False
        )
        self.act = AF.create(activation)
        self.conv2 = nn.Conv1d(
            bottleneck_channels, num_channels, kernel_size=1, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale the channels of ``x`` with the squeeze-excitation gate.

        Args:
          x: 3D tensor whose dimension 1 has ``num_channels`` entries.

        Returns:
          Tensor with the same shape as ``x``.
        """
        # Squeeze: average over dimension 2, kept as size 1 for broadcasting.
        z = torch.mean(x, dim=2, keepdim=True)
        # Excitation: bottleneck followed by a sigmoid gate in (0, 1).
        scale = self.sigmoid(self.conv2(self.act(self.conv1(z))))
        return scale * x
# Aliases to maintain backwards compatibility with the lowercase-"d" names.
SEBlock2d = SEBlock2D
TSEBlock2d = TSEBlock2D