Source code for hyperion.torch.layer_blocks.transformer_conv2d_subsampler
"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import torch
import torch.nn as nn
class TransformerConv2dSubsampler(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length) for transformer.

    Two stride-2, kernel-3 2D convolutions are applied to the input seen as
    a one-channel (feature x time) image. The feature axis is not padded,
    the time axis is padded by 1 on each side, so each convolution maps
    ``T`` frames to ``ceil(T / 2)`` frames. The resulting channels and
    feature bins are flattened and projected to ``out_feats`` by a linear
    layer followed by the positional encoder.

    Attributes:
      in_feats: input feature dimension
      out_feats: Transformer d_model
      hid_act: activation layer object (the same instance is applied after
        both convolutions)
      pos_enc: positional encoder layer
      time_dim: indicates which is the time dimension in the input tensor
    """

    def __init__(self, in_feats, out_feats, hid_act, pos_enc, time_dim=1):
        super().__init__()
        self.time_dim = time_dim
        # Tensors reach the convolutions as (batch, channel, feat, time):
        # padding=(0, 1) pads only the time axis.
        self.conv = nn.Sequential(
            nn.Conv2d(1, out_feats, 3, 2, padding=(0, 1)),
            hid_act,
            nn.Conv2d(out_feats, out_feats, 3, 2, padding=(0, 1)),
            hid_act,
        )
        # Number of feature bins left after two unpadded k=3, s=2 convs.
        sub_feats = ((in_feats - 1) // 2 - 1) // 2
        self.out = nn.Sequential(nn.Linear(out_feats * sub_feats, out_feats), pos_enc)

    def forward(self, x, mask=None):
        """Forward function.

        Args:
          x: input tensor with size=(batch, time, num_feats) when
            time_dim == 1; otherwise it is used as (batch, num_feats, time).
          mask: optional mask to indicate valid time steps for x; its last
            dimension must index the input time steps
            (batch, time1, time2).

        Returns:
          Tensor with output features, time-major (batch, time', out_feats)
            before pos_enc is applied.
          Tensor with subsampled mask whose last dimension has the same
            length time' as the output features, or None if mask is None.
        """
        if self.time_dim == 1:
            x = x.transpose(1, 2)

        x = x.unsqueeze(1)  # (b, c, f, t)
        x = self.conv(x)
        b, c, f, t = x.size()
        # Merge channels and feature bins, then move time to dim 1.
        x = self.out(x.reshape(b, c * f, t).transpose(1, 2))
        if mask is None:
            return x, None

        # Each conv pads time by 1 with kernel 3 / stride 2, so output frame
        # i is centred on input frame 2*i and there are ceil(T/2) of them.
        # Taking every 2nd mask entry (once per conv) keeps the mask aligned
        # with, and the same length as, the subsampled features.
        return x, mask[:, :, ::2][:, :, ::2]