Source code for hyperion.torch.layer_blocks.transformer_encoder_v1

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import torch
import torch.nn as nn

from ..layers.attention import *
from .transformer_feedforward import *


class TransformerEncoderBlockV1(nn.Module):
    """Building block for transformer encoder.

    Attributes:
      num_feats: input/output feat. dimension (aka d_model)
      self_attn: attention nn.Module or string in ['scaled-dot-prod-v1',
                 'local-scaled-dot-prod-v1']
      num_heads: number of heads
      feed_forward: position-wise feed-forward nn.Module or string in
                    ['linear', 'conv1dx2', 'conv1d-linear']
      d_ff: dimension of middle layer in feed_forward block
      ff_kernel_size: kernel size for convolutional versions of ff block
      ff_act: ff block hidden activation
      ff_dropout_rate: dropout rate for ff block; the same rate is also used
                       for the dropout applied to the attention and ff
                       outputs before each residual addition
      att_context: maximum context range for local attention
      att_dropout_rate: dropout rate for attention block
      rel_pos_enc: if True, use relative positional encodings, absolute
                   encodings otherwise.
      causal_pos_enc: if True, use causal positional encodings (when
                      rel_pos_enc=True), it assumes that query q_i only
                      attends to key k_j when j<=i
      norm_before: if True, use layer norm before layers, otherwise after
      concat_after: if True, it concatenates attention input and output and
                    applies a linear transform, i.e.,
                    y = x + linear(concat(x, att(x)));
                    if False, y = x + att(x)
    """

    def __init__(
        self,
        num_feats,
        self_attn,
        num_heads,
        feed_forward,
        d_ff,
        ff_kernel_size,
        ff_act="relu6",
        ff_dropout_rate=0,
        att_context=25,
        att_dropout_rate=0,
        rel_pos_enc=False,
        causal_pos_enc=False,
        norm_before=True,
        concat_after=False,
    ):
        super().__init__()
        # self_attn / feed_forward may be given either as ready-made modules
        # or as type strings that are turned into modules here.
        if isinstance(self_attn, str):
            self.self_attn = self._make_att(
                self_attn,
                num_feats,
                num_heads,
                att_context,
                att_dropout_rate,
                rel_pos_enc,
                causal_pos_enc,
            )
        else:
            self.self_attn = self_attn

        if isinstance(feed_forward, str):
            self.feed_forward = self._make_ff(
                feed_forward, num_feats, d_ff, ff_kernel_size, ff_act, ff_dropout_rate
            )
        else:
            self.feed_forward = feed_forward

        self.norm1 = nn.LayerNorm(num_feats)
        self.norm2 = nn.LayerNorm(num_feats)

        # Residual-branch dropout reuses the ff dropout rate; the Dropout
        # module is only created when it is actually needed.
        self.dropout_rate = ff_dropout_rate
        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate)

        self.norm_before = norm_before
        self.concat_after = concat_after
        if self.concat_after:
            # maps concat(x, att(x)) (2 * num_feats) back to num_feats
            self.concat_linear = nn.Linear(num_feats + num_feats, num_feats)

    @staticmethod
    def _make_att(
        att_type,
        num_feats,
        num_heads,
        context,
        dropout_rate,
        rel_pos_enc,
        causal_pos_enc,
    ):
        """Creates multihead attention block from att_type string.

        Args:
          att_type: string in ['scaled-dot-prod-v1', 'local-scaled-dot-prod-v1']
          num_feats: input/output feat. dimension (aka d_model)
          num_heads: number of heads; must divide num_feats
          context: maximum context range, only used by local attention
          dropout_rate: dropout rate for attention block
          rel_pos_enc: if True, use relative positional encodings, absolute
                       encodings otherwise.
          causal_pos_enc: if True, use causal positional encodings (when
                          rel_pos_enc=True), it assumes that query q_i only
                          attends to key k_j when j<=i

        Returns:
          Attention nn.Module

        Raises:
          ValueError: if num_feats is not divisible by num_heads, or if
            att_type is not one of the supported strings.
        """
        # Explicit check instead of assert: asserts vanish under `python -O`
        # and d_k would then be silently truncated.
        if num_feats % num_heads != 0:
            raise ValueError(
                f"num_feats={num_feats} must be divisible by num_heads={num_heads}"
            )
        d_k = num_feats // num_heads

        if att_type == "scaled-dot-prod-v1":
            if rel_pos_enc:
                return ScaledDotProdAttRelPosEncV1(
                    num_feats,
                    num_feats,
                    num_heads,
                    d_k,
                    d_k,
                    causal_pos_enc,
                    dropout_rate,
                    time_dim=1,
                )

            return ScaledDotProdAttV1(
                num_feats, num_feats, num_heads, d_k, d_k, dropout_rate, time_dim=1
            )

        if att_type == "local-scaled-dot-prod-v1":
            if rel_pos_enc:
                return LocalScaledDotProdAttRelPosEncV1(
                    num_feats,
                    num_feats,
                    num_heads,
                    d_k,
                    d_k,
                    context,
                    causal_pos_enc,
                    dropout_rate,
                    time_dim=1,
                )

            return LocalScaledDotProdAttV1(
                num_feats,
                num_feats,
                num_heads,
                d_k,
                d_k,
                context,
                dropout_rate,
                time_dim=1,
            )

        # Previously fell through and returned None, which only failed later
        # inside forward() with an unhelpful error.
        raise ValueError(
            f"unknown self-attention type {att_type!r}; expected one of "
            "['scaled-dot-prod-v1', 'local-scaled-dot-prod-v1']"
        )

    @staticmethod
    def _make_ff(ff_type, num_feats, hid_feats, kernel_size, activation, dropout_rate):
        """Creates position-wise feed forward block from ff_type string.

        Args:
          ff_type: string in ['linear', 'conv1dx2', 'conv1d-linear']
          num_feats: input/output feat. dimension (aka d_model)
          hid_feats: dimension of middle layer in feed_forward block
          kernel_size: kernel size for convolutional versions of ff block
          activation: activation function for ff block
          dropout_rate: dropout rate for ff block

        Returns:
          Position-wise feed-forward nn.Module

        Raises:
          ValueError: if ff_type is not one of the supported strings.
        """
        if ff_type == "linear":
            return PositionwiseFeedForward(
                num_feats, hid_feats, activation, dropout_rate, time_dim=1
            )

        if ff_type == "conv1dx2":
            return Conv1dx2(
                num_feats, hid_feats, kernel_size, activation, dropout_rate, time_dim=1
            )

        if ff_type == "conv1d-linear":
            return Conv1dLinear(
                num_feats, hid_feats, kernel_size, activation, dropout_rate, time_dim=1
            )

        # Previously fell through and returned None.
        raise ValueError(
            f"unknown feed-forward type {ff_type!r}; expected one of "
            "['linear', 'conv1dx2', 'conv1d-linear']"
        )

    def forward(self, x, pos_emb=None, mask=None):
        """Forward pass function.

        Args:
          x: input tensor with size=(batch, time, num_feats)
          pos_emb: positional embedding size=(batch, time2, in_feats)
                   as R_{L-1}, ..., R_0, when using relative positional
                   encoder, otherwise None
          mask: mask to indicate valid time steps for x (batch, time)

        Returns:
          Tensor with output features
          The input mask, returned unchanged
        """
        # --- self-attention sub-block ---
        residual = x
        if self.norm_before:
            x = self.norm1(x)

        # pos_emb is only forwarded when given, so attention modules that do
        # not take a pos_emb argument (absolute encodings) still work.
        if pos_emb is None:
            x_att = self.self_attn(x, x, x, mask=mask)
        else:
            x_att = self.self_attn(x, x, x, pos_emb=pos_emb, mask=mask)

        if self.concat_after:
            x = torch.cat((x, x_att), dim=-1)
            x = self.concat_linear(x)
        else:
            x = x_att

        if self.dropout_rate > 0:
            x = self.dropout(x)

        x = residual + x
        if not self.norm_before:
            x = self.norm1(x)

        # --- position-wise feed-forward sub-block ---
        residual = x
        if self.norm_before:
            x = self.norm2(x)

        x = self.feed_forward(x)
        if self.dropout_rate > 0:
            x = self.dropout(x)

        x = residual + x
        if not self.norm_before:
            x = self.norm2(x)

        return x, mask