# Source code for hyperion.torch.layer_blocks.dc2d_blocks

"""
 Copyright 2020 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import torch.nn as nn
from torch.nn import Conv2d, BatchNorm2d, Dropout2d

from ..layers import ActivationFactory as AF
from ..layers.subpixel_convs import SubPixelConv2d


class DC2dEncBlock(nn.Module):
    """Encoder block built around a single 2d convolution.

    ``forward`` applies, in order: ``conv1``; the normalization layer ``bn1``
    when ``norm_before`` is set; the activation (skipped if it is ``None``);
    ``bn1`` when ``norm_after`` is set; and ``Dropout2d`` when
    ``dropout_rate > 0``.

    Args:
      in_channels: input channels of ``conv1``.
      out_channels: output channels of ``conv1``; also the size passed to the
        normalization layer constructor.
      kernel_size: convolution kernel size. The padding arithmetic uses
        ``kernel_size - 1``, so a scalar int is assumed (TODO confirm tuples
        are never passed).
      stride: convolution stride.
      dilation: convolution dilation.
      activation: spec forwarded to ``ActivationFactory.create``.
      dropout_rate: ``Dropout2d`` probability; no dropout module is created
        unless it is > 0.
      use_norm: whether to build the normalization layer ``bn1``.
      norm_layer: callable that builds the normalization layer from
        ``out_channels``; ``BatchNorm2d`` when ``None``.
      norm_before: if true, normalize before the activation and build
        ``conv1`` without bias; otherwise normalize after the activation.
        Has no effect when ``use_norm`` is false.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        activation="relu",
        dropout_rate=0,
        use_norm=True,
        norm_layer=None,
        norm_before=True,
    ):
        super().__init__()
        # Submodules are registered in the order activation, dropout, bn1,
        # conv1; keep it so parameter / state_dict ordering does not change.
        self.activation = AF.create(activation)

        # dilation * (kernel_size - 1) is the dilated kernel extent minus one.
        # Half of it is used as padding, which keeps the spatial size for odd
        # kernel sizes when stride == 1.
        dilated_span = dilation * (kernel_size - 1)
        padding = int(dilated_span / 2)

        self.dropout_rate = dropout_rate
        self.dropout = Dropout2d(dropout_rate) if dropout_rate > 0 else None

        # At most one flag is True; both stay False when use_norm is off.
        self.norm_before = bool(use_norm and norm_before)
        self.norm_after = bool(use_norm) and not norm_before
        if use_norm:
            make_norm = BatchNorm2d if norm_layer is None else norm_layer
            self.bn1 = make_norm(out_channels)

        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            # The bias is dropped only when normalization directly follows
            # the conv (presumably because the norm's shift makes it redundant).
            bias=not self.norm_before,
        )
        self.stride = stride
        # Half the dilated kernel span (integer division); presumably the
        # one-sided context consumed by callers — confirm at call sites.
        self.context = dilated_span // 2

    def freeze(self):
        """Turn off gradient tracking for every parameter in this block."""
        for weight in self.parameters():
            weight.requires_grad_(False)

    def unfreeze(self):
        """Turn gradient tracking back on for every parameter in this block."""
        for weight in self.parameters():
            weight.requires_grad_(True)

    def forward(self, x):
        """Run conv -> [norm] -> activation -> [norm] -> [dropout] on ``x``.

        ``x`` is passed straight to ``conv1`` (a ``Conv2d``), so it must have a
        layout that ``Conv2d`` accepts with ``in_channels`` channels.
        """
        out = self.conv1(x)
        if self.norm_before:
            out = self.bn1(out)
        if self.activation is not None:
            out = self.activation(out)
        if self.norm_after:
            out = self.bn1(out)
        if self.dropout_rate > 0:
            out = self.dropout(out)
        return out
class DC2dDecBlock(nn.Module):
    """Decoder block built around a single 2d convolution.

    With ``stride == 1`` the convolution is a plain ``Conv2d``. Otherwise it is
    the project's ``SubPixelConv2d``, which is presumably an upsampling
    (sub-pixel) conv by ``stride`` — confirm in ``..layers.subpixel_convs``.

    ``forward`` applies, in order: ``conv1``; ``bn1`` when ``norm_before`` is
    set; the activation (skipped if it is ``None``); ``bn1`` when
    ``norm_after`` is set; and ``Dropout2d`` when ``dropout_rate > 0``.

    Args:
      in_channels: input channels of ``conv1``.
      out_channels: output channels of ``conv1``; also the size passed to the
        normalization layer constructor.
      kernel_size: convolution kernel size. The padding arithmetic uses
        ``kernel_size - 1``, so a scalar int is assumed (TODO confirm).
      stride: 1 selects ``Conv2d``; any other value is forwarded to
        ``SubPixelConv2d``.
      dilation: convolution dilation.
      activation: spec forwarded to ``ActivationFactory.create``.
      dropout_rate: ``Dropout2d`` probability; no dropout module is created
        unless it is > 0.
      use_norm: whether to build the normalization layer ``bn1``.
      norm_layer: callable that builds the normalization layer from
        ``out_channels``; ``BatchNorm2d`` when ``None``.
      norm_before: if true, normalize before the activation and build
        ``conv1`` without bias; otherwise normalize after the activation.
        Has no effect when ``use_norm`` is false.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        activation="relu",
        dropout_rate=0,
        use_norm=True,
        norm_layer=None,
        norm_before=True,
    ):
        super().__init__()
        # Submodules are registered in the order activation, dropout, bn1,
        # conv1; keep it so parameter / state_dict ordering does not change.
        self.activation = AF.create(activation)

        # Half of the dilated kernel extent minus one is used as padding.
        dilated_span = dilation * (kernel_size - 1)
        padding = int(dilated_span / 2)

        self.dropout_rate = dropout_rate
        self.dropout = Dropout2d(dropout_rate) if dropout_rate > 0 else None

        # At most one flag is True; both stay False when use_norm is off.
        self.norm_before = bool(use_norm and norm_before)
        self.norm_after = bool(use_norm) and not norm_before
        if use_norm:
            make_norm = BatchNorm2d if norm_layer is None else norm_layer
            self.bn1 = make_norm(out_channels)

        # Arguments shared by both convolution flavours. The bias is dropped
        # only when normalization directly follows the conv.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            dilation=dilation,
            bias=(not self.norm_before),
            padding=padding,
        )
        if stride == 1:
            self.conv1 = Conv2d(in_channels, out_channels, stride=1, **conv_kwargs)
        else:
            # Project-local layer; its exact stride semantics are defined in
            # ..layers.subpixel_convs (presumably upsampling — confirm).
            self.conv1 = SubPixelConv2d(
                in_channels, out_channels, stride=stride, **conv_kwargs
            )

        self.stride = stride
        # Half the dilated kernel span (integer division); presumably the
        # one-sided context consumed by callers — confirm at call sites.
        self.context = dilated_span // 2

    def freeze(self):
        """Turn off gradient tracking for every parameter in this block."""
        for weight in self.parameters():
            weight.requires_grad_(False)

    def unfreeze(self):
        """Turn gradient tracking back on for every parameter in this block."""
        for weight in self.parameters():
            weight.requires_grad_(True)

    def forward(self, x):
        """Run conv -> [norm] -> activation -> [norm] -> [dropout] on ``x``.

        ``x`` is passed straight to ``conv1``; its expected layout is whatever
        that module accepts (``Conv2d`` or ``SubPixelConv2d``).
        """
        out = self.conv1(x)
        if self.norm_before:
            out = self.bn1(out)
        if self.activation is not None:
            out = self.activation(out)
        if self.norm_after:
            out = self.bn1(out)
        if self.dropout_rate > 0:
            out = self.dropout(out)
        return out