Source code for hyperion.torch.layers.subpixel_convs

"""
 Copyright 2020 Johns Hopkins University  (Author: Jesus Villalba, Nanxin Chen)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
#

import torch
import torch.nn as nn


class SubPixelConv1d(nn.Module):
    """1d sub-pixel convolution used for upsampling along the time axis.

    A stride-1 ``nn.Conv1d`` produces ``stride * out_channels`` channels,
    which are then rearranged into ``out_channels`` channels with ``stride``
    times more time steps.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels after the rearrangement.
      kernel_size: conv kernel size.
      stride: upsampling factor (the conv itself always runs at stride 1).
      padding, dilation, groups, bias, padding_mode: forwarded to nn.Conv1d.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super().__init__()
        # One block of out_channels per sub-sample position; the actual
        # upsampling is a pure reshape performed in forward().
        conv_channels = stride * out_channels
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=conv_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        self.out_channels = out_channels
        self.stride = stride

    def forward(self, x):
        """Apply the conv and interleave its channel groups in time.

        Args:
          x: input tensor for nn.Conv1d, e.g. (batch, in_channels, T).

        Returns:
          The conv output unchanged when stride == 1; otherwise a tensor of
          shape (-1, out_channels, T * stride).
        """
        y = self.conv(x)
        factor = self.stride
        if factor == 1:
            return y

        channels = self.out_channels
        steps = y.size(-1)
        # conv channel s * out_channels + o feeds output sample
        # t * stride + s of output channel o
        y = y.view(-1, factor, channels, steps)
        y = y.permute(0, 2, 3, 1)
        return y.reshape(-1, channels, steps * factor)
class SubPixelConv2d(nn.Module):
    """2d sub-pixel convolution: stride-1 conv followed by a pixel shuffle.

    A stride-1 ``nn.Conv2d`` produces ``stride**2 * out_channels`` channels,
    which ``nn.PixelShuffle(stride)`` rearranges into ``out_channels``
    channels with both spatial dims multiplied by ``stride``.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels after the pixel shuffle.
      kernel_size: conv kernel size.
      stride: upsampling factor (the conv itself always runs at stride 1).
      padding, dilation, groups, bias, padding_mode: forwarded to nn.Conv2d.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels * stride ** 2,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        self.stride = stride
        # No shuffle module is created for stride 1; forward() skips it.
        if stride > 1:
            self.pixel_shuffle = nn.PixelShuffle(stride)

    def forward(self, x):
        """Apply the conv and, if stride > 1, the pixel shuffle."""
        y = self.conv(x)
        return y if self.stride == 1 else self.pixel_shuffle(y)
def ICNR2d(tensor, stride=2, initializer=nn.init.kaiming_normal_):
    """Initialize 2d sub-pixel conv weights with ICNR.

    Implements "Initialization to Convolution Nearest neighbours Resize"
    (ICNR) for sub-pixel convolutions, described in Andrew Aitken et al.
    (2017), "Checkerboard artifact free sub-pixel convolution",
    https://arxiv.org/abs/1707.02937

    A sub-kernel with ``tensor.shape[0] // stride**2`` output channels is
    initialized with ``initializer``, and each of its output channels is
    repeated ``stride**2`` times consecutively. Channels
    ``o * stride**2 .. o * stride**2 + stride**2 - 1`` therefore share the
    same weights. This is the channel order consumed by ``nn.PixelShuffle``
    (used by SubPixelConv2d). At initialization, conv + pixel shuffle then
    behaves like the sub-kernel conv followed by nearest-neighbour
    upsampling (ignoring the conv bias, which is not touched here).

    Args:
      tensor: torch.Tensor containing the conv weights, shape
        (stride**2 * out_channels, in_channels / groups, kh, kw). It is
        overwritten in place.
      stride: subpixel conv stride (upsampling factor).
      initializer: initializer used for the sub-kernel. It receives a zero
        tensor and must return the initialized tensor, like the
        ``torch.nn.init.*_`` functions.

    Raises:
      ValueError: if ``tensor.shape[0]`` is not divisible by ``stride**2``.

    Examples:
      >>> conv = SubPixelConv2d(in_channels, out_channels, kernel_size=3, stride=upscale)
      >>> ICNR2d(conv.conv.weight, stride=upscale)
    """
    num_positions = stride ** 2
    sub_channels, remainder = divmod(tensor.shape[0], num_positions)
    if remainder != 0:
        raise ValueError(
            f"tensor.shape[0]={tensor.shape[0]} is not divisible by "
            f"stride**2={num_positions}"
        )

    with torch.no_grad():
        subkernel = torch.zeros([sub_channels] + list(tensor.shape[1:]))
        subkernel = initializer(subkernel)
        # PixelShuffle maps input channel c to output channel
        # c // stride**2, so each sub-kernel channel is repeated
        # consecutively stride**2 times.
        kernel = subkernel.repeat_interleave(num_positions, dim=0)
        tensor.copy_(kernel)
def ICNR1d(tensor, stride=2, initializer=nn.init.kaiming_normal_):
    """1d version of the ICNR initialization for sub-pixel convolutions.

    Implements "Initialization to Convolution Nearest neighbours Resize"
    (ICNR), described in Andrew Aitken et al. (2017), "Checkerboard artifact
    free sub-pixel convolution", https://arxiv.org/abs/1707.02937

    A sub-kernel with ``tensor.shape[0] // stride`` output channels is
    initialized with ``initializer`` and tiled ``stride`` times along dim 0.
    Channel ``s * out_channels + o`` therefore shares the weights of
    sub-kernel channel ``o``. This is the channel split used by
    ``SubPixelConv1d.forward``, which views the conv output channels as
    (stride, out_channels). At initialization, the sub-pixel conv then
    behaves like the sub-kernel conv followed by nearest-neighbour
    upsampling in time (ignoring the conv bias, which is not touched here).

    Args:
      tensor: torch.Tensor containing the conv weights, shape
        (stride * out_channels, in_channels / groups, kernel_size). It is
        overwritten in place.
      stride: subpixel conv stride (upsampling factor).
      initializer: initializer used for the sub-kernel. It receives a zero
        tensor and must return the initialized tensor, like the
        ``torch.nn.init.*_`` functions.

    Raises:
      ValueError: if ``tensor.shape[0]`` is not divisible by ``stride``.

    Examples:
      >>> conv = SubPixelConv1d(in_channels, out_channels, kernel_size=3, stride=upscale)
      >>> ICNR1d(conv.conv.weight, stride=upscale)
    """
    sub_channels, remainder = divmod(tensor.shape[0], stride)
    if remainder != 0:
        raise ValueError(
            f"tensor.shape[0]={tensor.shape[0]} is not divisible by "
            f"stride={stride}"
        )

    with torch.no_grad():
        subkernel = torch.zeros([sub_channels] + list(tensor.shape[1:]))
        subkernel = initializer(subkernel)
        # Tile the whole sub-kernel (s-major) rather than repeating each
        # channel consecutively (o * stride + s). The consecutive layout
        # does not match SubPixelConv1d's (stride, out_channels) channel
        # split and would mix kernels of different output channels.
        reps = [stride] + [1] * (subkernel.dim() - 1)
        kernel = subkernel.repeat(*reps)
        tensor.copy_(kernel)