"""
Copyright 2020 Johns Hopkins University (Author: Jesus Villalba, Nanxin Chen)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
#
import torch
import torch.nn as nn
class SubPixelConv1d(nn.Module):
    """1D sub-pixel upsampling convolution.

    A stride-1 ``nn.Conv1d`` produces ``stride * out_channels`` channels,
    which are then rearranged along the time axis. The result has
    ``out_channels`` channels and ``stride`` times as many frames as the
    convolution output. Convolution channel ``s * out_channels + c`` supplies
    output channel ``c`` at frames ``t * stride + s``; that is, the conv
    channels are grouped as (stride, out_channels).

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels after the rearrangement.
      kernel_size: convolution kernel size, forwarded to ``nn.Conv1d``.
      stride: upsampling factor. The convolution itself always runs with
        stride 1.
      padding, dilation, groups, bias, padding_mode: forwarded unchanged to
        ``nn.Conv1d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super().__init__()
        self.out_channels = out_channels
        self.stride = stride
        conv_opts = dict(
            stride=1,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        # One conv output channel per (phase, output channel) pair.
        self.conv = nn.Conv1d(
            in_channels, stride * out_channels, kernel_size, **conv_opts
        )

    def forward(self, x):
        """Convolve and upsample ``x`` in time by ``self.stride``.

        Args:
          x: input for ``self.conv``; assumed (batch, in_channels, time).

        Returns:
          The convolution output unchanged when ``stride == 1``. Otherwise a
          tensor of shape (batch, out_channels, time_out * stride), where
          time_out is the convolution output length.
        """
        h = self.conv(x)
        if self.stride != 1:
            num_frames = h.size(-1)
            # (B, stride*C, T) -> (B, stride, C, T) -> (B, C, T, stride) -> (B, C, T*stride)
            phases = h.view(-1, self.stride, self.out_channels, num_frames)
            h = phases.permute(0, 2, 3, 1).reshape(
                -1, self.out_channels, num_frames * self.stride
            )
        return h
class SubPixelConv2d(nn.Module):
    """2D sub-pixel upsampling convolution.

    A stride-1 ``nn.Conv2d`` produces ``stride**2 * out_channels`` channels.
    When ``stride > 1``, these are rearranged by ``nn.PixelShuffle(stride)``
    into ``out_channels`` maps whose height and width are both ``stride``
    times larger than the convolution output.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels after pixel shuffling.
      kernel_size: convolution kernel size, forwarded to ``nn.Conv2d``.
      stride: upsampling factor. The convolution itself always runs with
        stride 1.
      padding, dilation, groups, bias, padding_mode: forwarded unchanged to
        ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super().__init__()
        self.stride = stride
        # One output map per (channel, row phase, column phase) triple.
        num_phases = stride ** 2
        self.conv = nn.Conv2d(
            in_channels,
            num_phases * out_channels,
            kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        # Only registered when there is something to rearrange.
        if stride > 1:
            self.pixel_shuffle = nn.PixelShuffle(self.stride)

    def forward(self, x):
        """Convolve ``x`` and, if ``stride != 1``, pixel-shuffle the result."""
        y = self.conv(x)
        return y if self.stride == 1 else self.pixel_shuffle(y)
def ICNR2d(tensor, stride=2, initializer=nn.init.kaiming_normal_):
    """ICNR initialization for 2D sub-pixel (pixel-shuffle) convolutions.

    "Initialization to Convolution Nearest neighbours Resize" (ICNR), as
    described in Andrew Aitken et al. (2017), "Checkerboard artifact free
    sub-pixel convolution", https://arxiv.org/abs/1707.02937

    The steps are:

    1. Draw a sub-kernel with ``tensor.shape[0] // stride**2`` output
       channels using ``initializer``.
    2. Repeat each of its output channels ``stride**2`` times consecutively
       along dim 0, so channel ``c * stride**2 + k`` equals sub-kernel
       channel ``c``.

    This is the grouping that ``nn.PixelShuffle`` uses, so right after
    initialization the sub-pixel conv acts like a convolution followed by
    nearest-neighbour upsampling. Only ``tensor`` is written; a conv bias,
    if any, is left untouched.

    Args:
      tensor: conv weight of shape
        (out_channels * stride**2, in_channels // groups, kh, kw). It is
        overwritten in place.
      stride: sub-pixel upsampling factor (>= 1).
      initializer: init function applied to the sub-kernel, e.g. an
        in-place ``nn.init.*_`` function. If it returns None, the sub-kernel
        it modified in place is used.

    Raises:
      ValueError: if ``stride < 1`` or ``tensor.shape[0]`` is not divisible
        by ``stride**2``.

    Examples:
      >>> conv = SubPixelConv2d(in_channels, out_channels, kernel_size=3, stride=upscale)
      >>> ICNR2d(conv.conv.weight, stride=upscale)
    """
    if stride < 1:
        raise ValueError("stride must be >= 1, got {}".format(stride))
    num_phases = stride ** 2
    total_out = tensor.shape[0]
    if total_out % num_phases != 0:
        raise ValueError(
            "tensor.shape[0]={} is not divisible by stride**2={}".format(
                total_out, num_phases
            )
        )
    with torch.no_grad():
        subkernel = torch.zeros([total_out // num_phases] + list(tensor.shape[1:]))
        initialized = initializer(subkernel)
        if initialized is not None:
            subkernel = initialized
        # Rows c*num_phases ... c*num_phases + num_phases - 1 all equal subkernel[c],
        # matching nn.PixelShuffle's channel grouping.
        kernel = subkernel.repeat_interleave(num_phases, dim=0)
        # copy_ also handles the dtype/device of the target weight.
        tensor.copy_(kernel)
def ICNR1d(tensor, stride=2, initializer=nn.init.kaiming_normal_):
    """1D ICNR initialization for ``SubPixelConv1d`` weights.

    "Initialization to Convolution Nearest neighbours Resize" (ICNR), as
    described in Andrew Aitken et al. (2017), "Checkerboard artifact free
    sub-pixel convolution", https://arxiv.org/abs/1707.02937

    The steps are:

    1. Draw a sub-kernel with ``n = tensor.shape[0] // stride`` output
       channels using ``initializer``.
    2. Tile it ``stride`` times as whole blocks along dim 0, so channel
       ``s * n + c`` equals sub-kernel channel ``c``.

    ``SubPixelConv1d`` groups its conv channels as (stride, out_channels):
    channel ``s * out_channels + c`` becomes phase ``s`` of output channel
    ``c``. With this tiling, every phase of an output channel gets the same
    filter, so right after initialization the layer acts like a convolution
    followed by nearest-neighbour upsampling. Only ``tensor`` is written; a
    conv bias, if any, is left untouched.

    Args:
      tensor: conv weight of shape
        (out_channels * stride, in_channels // groups, kernel_size). It is
        overwritten in place.
      stride: sub-pixel upsampling factor (>= 1).
      initializer: init function applied to the sub-kernel, e.g. an
        in-place ``nn.init.*_`` function. If it returns None, the sub-kernel
        it modified in place is used.

    Raises:
      ValueError: if ``stride < 1`` or ``tensor.shape[0]`` is not divisible
        by ``stride``.

    Examples:
      >>> conv = SubPixelConv1d(in_channels, out_channels, kernel_size=3, stride=upscale)
      >>> ICNR1d(conv.conv.weight, stride=upscale)
    """
    if stride < 1:
        raise ValueError("stride must be >= 1, got {}".format(stride))
    total_out = tensor.shape[0]
    if total_out % stride != 0:
        raise ValueError(
            "tensor.shape[0]={} is not divisible by stride={}".format(
                total_out, stride
            )
        )
    with torch.no_grad():
        subkernel = torch.zeros([total_out // stride] + list(tensor.shape[1:]))
        initialized = initializer(subkernel)
        if initialized is not None:
            subkernel = initialized
        # Tile whole blocks (not interleave): phase-major layout used by
        # SubPixelConv1d, i.e. channel s * n + c == subkernel[c].
        kernel = subkernel.repeat(stride, *([1] * (subkernel.dim() - 1)))
        # copy_ also handles the dtype/device of the target weight.
        tensor.copy_(kernel)