Source code for hyperion.torch.narchs.transformer_encoder_v1
"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba, Nanxin Chen)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
from jsonargparse import ArgumentParser, ActionParser
import torch
import torch.nn as nn
from ..layers import ActivationFactory as AF
from ..layers import PosEncoder, RelPosEncoder
from ..layer_blocks import TransformerEncoderBlockV1 as EBlock
from ..layer_blocks import TransformerConv2dSubsampler as Conv2dSubsampler
from .net_arch import NetArch
class TransformerEncoderV1(NetArch):
    """Transformer encoder module.

    The network is an input layer (which also applies the positional
    encoder), followed by ``num_blocks`` transformer encoder blocks and,
    when ``norm_before`` is True, a final layer norm.

    Attributes:
      in_feats: input features dimension
      d_model: encoder blocks feature dimension
      num_heads: number of heads
      num_blocks: number of self attn blocks
      att_type: string in ['scaled-dot-prod-v1', 'local-scaled-dot-prod-v1']
      att_context: maximum context range for local attention
      ff_type: string in ['linear', 'conv1dx2', 'conv1d-linear']
      d_ff: dimension of middle layer in feed_forward block
      ff_kernel_size: kernel size for convolutional versions of ff block
      ff_dropout_rate: dropout rate for ff block, it is also the dropout
                       rate of the 'linear' input layer
      pos_dropout_rate: dropout rate for positional encoder
      att_dropout_rate: dropout rate for attention block
      in_layer_type: input layer block type in ['linear', 'conv2d-sub', 'embed', None],
                     or an nn.Module instance to be used as input layer
      rel_pos_enc: if True, use relative positional encodings, absolute encodings otherwise.
      causal_pos_enc: if True, use causal positional encodings (when rel_pos_enc=True), it assumes
                      that query q_i only attends to key k_j when j<=i
      hid_act: hidden activations in ff and input blocks
      norm_before: if True, use layer norm before layers, otherwise after
      concat_after: if True, it concats attention input and output and applies a linear transform, i.e.,
                             y = x + linear(concat(x, att(x)))
                    if False, y = x + att(x)
      padding_idx: padding idx for embed layer
      in_time_dim: time dimension in the input Tensor
      out_time_dim: dimension that we want to be time in the output tensor
    """

    def __init__(
        self,
        in_feats,
        d_model=256,
        num_heads=4,
        num_blocks=6,
        att_type="scaled-dot-prod-v1",
        att_context=25,
        ff_type="linear",
        d_ff=2048,
        ff_kernel_size=1,
        ff_dropout_rate=0.1,
        pos_dropout_rate=0.1,
        att_dropout_rate=0.0,
        in_layer_type="conv2d-sub",
        rel_pos_enc=False,
        causal_pos_enc=False,
        hid_act="relu6",
        norm_before=True,
        concat_after=False,
        padding_idx=-1,
        in_time_dim=-1,
        out_time_dim=1,
    ):
        super().__init__()

        # Keep every constructor argument so that get_config() can return it.
        self.in_feats = in_feats
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.att_type = att_type
        self.att_context = att_context
        self.ff_type = ff_type
        self.d_ff = d_ff
        self.ff_kernel_size = ff_kernel_size
        self.ff_dropout_rate = ff_dropout_rate
        self.rel_pos_enc = rel_pos_enc
        self.causal_pos_enc = causal_pos_enc
        self.att_dropout_rate = att_dropout_rate
        self.pos_dropout_rate = pos_dropout_rate
        self.in_layer_type = in_layer_type
        self.norm_before = norm_before
        self.concat_after = concat_after
        self.padding_idx = padding_idx
        self.in_time_dim = in_time_dim
        self.out_time_dim = out_time_dim
        self.hid_act = hid_act

        # The input layer is created before the encoder blocks, it reads
        # its configuration from the attributes set above.
        self._make_in_layer()

        self.blocks = nn.ModuleList(
            [
                EBlock(
                    d_model,
                    att_type,
                    num_heads,
                    ff_type,
                    d_ff,
                    ff_kernel_size,
                    ff_act=hid_act,
                    ff_dropout_rate=ff_dropout_rate,
                    att_context=att_context,
                    att_dropout_rate=att_dropout_rate,
                    rel_pos_enc=rel_pos_enc,
                    causal_pos_enc=causal_pos_enc,
                    norm_before=norm_before,
                    concat_after=concat_after,
                )
                for _ in range(num_blocks)
            ]
        )

        # The final layer norm only exists when norm_before=True.
        if self.norm_before:
            self.norm = nn.LayerNorm(d_model)

    def _make_in_layer(self):
        """Creates the input layer and stores it in self.in_layer.

        The layer is selected by self.in_layer_type, and the positional
        encoder (relative or absolute, depending on self.rel_pos_enc) is
        always part of it.

        Raises:
          ValueError: if self.in_layer_type is not a known type.
        """
        in_feats = self.in_feats
        d_model = self.d_model
        dropout_rate = self.ff_dropout_rate

        if self.rel_pos_enc:
            pos_enc = RelPosEncoder(d_model, self.pos_dropout_rate)
        else:
            pos_enc = PosEncoder(d_model, self.pos_dropout_rate)

        hid_act = AF.create(self.hid_act)

        if self.in_layer_type == "linear":
            self.in_layer = nn.Sequential(
                nn.Linear(in_feats, d_model),
                nn.LayerNorm(d_model),
                nn.Dropout(dropout_rate),
                hid_act,
                pos_enc,
            )
        elif self.in_layer_type == "conv2d-sub":
            self.in_layer = Conv2dSubsampler(
                in_feats, d_model, hid_act, pos_enc, time_dim=self.in_time_dim
            )
        elif self.in_layer_type == "embed":
            self.in_layer = nn.Sequential(
                nn.Embedding(in_feats, d_model, padding_idx=self.padding_idx), pos_enc
            )
        elif isinstance(self.in_layer_type, nn.Module):
            # User-provided input layer: it has to be read from self, there
            # is no local variable with this name in this method.
            self.in_layer = nn.Sequential(self.in_layer_type, pos_enc)
        elif self.in_layer_type is None:
            self.in_layer = pos_enc
        else:
            # f-string, so that non-string values also produce a ValueError
            # instead of failing in the string concatenation.
            raise ValueError(f"unknown in_layer_type: {self.in_layer_type}")

    def forward(self, x, mask=None, target_shape=None, use_amp=False):
        """Forward pass function, optionally under CUDA autocast.

        Args:
          x: input tensor, see _forward
          mask: mask to indicate valid time steps for x, see _forward
          target_shape: passed to _forward, where it is not used
          use_amp: if True, _forward runs inside torch.cuda.amp.autocast()

        Returns:
          Same outputs as _forward
        """
        if use_amp:
            with torch.cuda.amp.autocast():
                return self._forward(x, mask, target_shape)

        return self._forward(x, mask, target_shape)

    def _forward(self, x, mask=None, target_shape=None):
        """Forward pass function

        Args:
          x: input tensor with size=(batch, time, num_feats). When the input
             layer is not the conv2d subsampler and in_time_dim != 1,
             dimensions 1 and in_time_dim are swapped before the input layer.
          mask: mask to indicate valid time steps for x (batch, time)
          target_shape: not used by this network

        Returns:
          Tensor with output features
          Tensor with mask, only returned when the mask is not None
        """
        if isinstance(self.in_layer, Conv2dSubsampler):
            # The subsampler is given time_dim at construction and it
            # returns the mask together with the features.
            x, mask = self.in_layer(x, mask)
        else:
            if self.in_time_dim != 1:
                x = x.transpose(1, self.in_time_dim).contiguous()

            x = self.in_layer(x)

        # Some input layers return a tuple (x, pos_emb); in that case the
        # positional embeddings are passed to every encoder block.
        if isinstance(x, tuple):
            x, pos_emb = x
            b_args = {"pos_emb": pos_emb}
        else:
            b_args = {}

        for block in self.blocks:
            x, mask = block(x, mask=mask, **b_args)

        if self.norm_before:
            x = self.norm(x)

        if self.out_time_dim != 1:
            x = x.transpose(1, self.out_time_dim)

        if mask is None:
            return x

        return x, mask

    def get_config(self):
        """Gets network config

        Returns:
          dictionary with config params, the ones of this class are merged
          on top of the ones returned by the base class
        """
        config = {
            "in_feats": self.in_feats,
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "num_blocks": self.num_blocks,
            "att_type": self.att_type,
            "att_context": self.att_context,
            "ff_type": self.ff_type,
            "d_ff": self.d_ff,
            "ff_kernel_size": self.ff_kernel_size,
            "ff_dropout_rate": self.ff_dropout_rate,
            "att_dropout_rate": self.att_dropout_rate,
            "pos_dropout_rate": self.pos_dropout_rate,
            "in_layer_type": self.in_layer_type,
            "rel_pos_enc": self.rel_pos_enc,
            "causal_pos_enc": self.causal_pos_enc,
            "hid_act": self.hid_act,
            "norm_before": self.norm_before,
            "concat_after": self.concat_after,
            "padding_idx": self.padding_idx,
            "in_time_dim": self.in_time_dim,
            "out_time_dim": self.out_time_dim,
        }

        base_config = super().get_config()
        # Keys of this class take precedence over the base class ones.
        return {**base_config, **config}

    def in_shape(self):
        """Input shape for network

        Returns:
          Tuple describing input shape, None marks the dimensions that are
          not fixed
        """
        if self.in_time_dim == 1:
            return (None, None, self.in_feats)

        return (None, self.in_feats, None)

    def out_shape(self, in_shape=None):
        """Infers the network output shape given the input shape

        Args:
          in_shape: input shape tuple with 3 dimensions, or None

        Returns:
          Tuple with the output shape, None marks the dimensions that
          cannot be inferred
        """
        if in_shape is None:
            out_t = None
            batch_size = None
        else:
            assert len(in_shape) == 3, "in_shape must have 3 dimensions"
            batch_size = in_shape[0]
            in_t = in_shape[self.in_time_dim]
            if in_t is None:
                out_t = None
            elif isinstance(self.in_layer, Conv2dSubsampler):
                # Output length when the input layer is the conv2d
                # subsampler, this is not exactly in_t // 4.
                out_t = ((in_t - 1) // 2 - 1) // 2
            else:
                out_t = in_t

        if self.out_time_dim == 1:
            return (batch_size, out_t, self.d_model)

        return (batch_size, self.d_model, out_t)

    @staticmethod
    def filter_args(**kwargs):
        """Filters arguments corresponding to the transformer encoder
            from args dictionary

        Args:
          kwargs: args dictionary

        Returns:
          args dictionary with only the keys listed in valid_args that are
          present in kwargs
        """
        valid_args = (
            "num_blocks",
            "in_feats",
            "d_model",
            "num_heads",
            "att_type",
            "att_context",
            "ff_type",
            "d_ff",
            "ff_kernel_size",
            "ff_dropout_rate",
            "pos_dropout_rate",
            "att_dropout_rate",
            "in_layer_type",
            "hid_act",
            "rel_pos_enc",
            "causal_pos_enc",
            "concat_after",
        )

        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None, in_feats=False):
        """Adds Transformer config parameters to argparser

        Args:
          parser: argparse object
          prefix: prefix string to add to the argument names. When given,
                  the options are added to a new parser that is attached
                  to `parser` as the option --<prefix>.
          in_feats: if True, it also adds the --in-feats option
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        if in_feats:
            parser.add_argument(
                "--in-feats", type=int, default=80, help=("input feature dimension")
            )

        parser.add_argument(
            "--num-blocks", default=6, type=int, help=("number of tranformer blocks")
        )

        parser.add_argument(
            "--d-model", default=512, type=int, help=("encoder layer sizes")
        )

        parser.add_argument(
            "--num-heads",
            default=4,
            type=int,
            help=("number of heads in self-attention layers"),
        )

        parser.add_argument(
            "--att-type",
            default="scaled-dot-prod-v1",
            choices=["scaled-dot-prod-v1", "local-scaled-dot-prod-v1"],
            help=("type of self-attention"),
        )

        parser.add_argument(
            "--att-context",
            default=25,
            type=int,
            help=("context size when using local attention"),
        )

        # NOTE(review): the class docstring spells the last choice as
        # 'conv1d-linear'; confirm which spelling the encoder block accepts.
        parser.add_argument(
            "--ff-type",
            default="linear",
            choices=["linear", "conv1dx2", "conv1dlinear"],
            help=("type of feed forward layers in transformer block"),
        )

        parser.add_argument(
            "--d-ff",
            default=2048,
            type=int,
            help=("size middle layer in feed forward block"),
        )

        parser.add_argument(
            "--ff-kernel-size",
            default=3,
            type=int,
            help=("kernel size in convolutional feed forward block"),
        )

        # Best effort: a failure to add --hid-act is ignored, presumably
        # because the option may already exist in a shared parser (confirm).
        # Only Exception is caught, so KeyboardInterrupt and SystemExit
        # still propagate.
        try:
            parser.add_argument("--hid-act", default="relu6", help="hidden activation")
        except Exception:
            pass

        parser.add_argument(
            "--pos-dropout-rate",
            default=0.1,
            type=float,
            help="positional encoder dropout",
        )
        parser.add_argument(
            "--att-dropout-rate", default=0, type=float, help="self-att dropout"
        )
        parser.add_argument(
            "--ff-dropout-rate",
            default=0.1,
            type=float,
            help="feed-forward layer dropout",
        )

        parser.add_argument(
            "--in-layer-type",
            default="linear",
            choices=["linear", "conv2d-sub"],
            help=("type of input layer"),
        )

        parser.add_argument(
            "--rel-pos-enc",
            default=False,
            action="store_true",
            help="use relative positional encoder",
        )
        parser.add_argument(
            "--causal-pos-enc",
            default=False,
            action="store_true",
            help="relative positional encodings are zero when attending to the future",
        )

        parser.add_argument(
            "--concat-after",
            default=False,
            action="store_true",
            help="concatenate attention input and output instead of adding",
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    # Alias of add_class_args under a second name.
    add_argparse_args = add_class_args