Source code for hyperion.torch.layer_blocks.conformer_encoder_v1

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
#

import torch
import torch.nn as nn

from ..layers.attention import *
from .transformer_feedforward import *
from .conformer_conv import ConformerConvBlock


[docs]class ConformerEncoderBlockV1(nn.Module): """Building block for conformer encoder introduced in https://arxiv.org/pdf/2005.08100.pdf This includes some optional extra features not included in the original paper: - Choose local-attention (attending only to close frames instead of all the frames in the sequence) - Choose number of conv blocks - Squeeze-Excitation after depthwise-conv - Allows downsampling in time dimension - Allows choosing activation and layer normalization type We call this Conformer+ Attributes: num_feats: input/output feat. dimension (aka d_model) self_attn: attention module in ['scaled-dot-prod-att-v1', 'local-scaled-dot-prod-att-v1'] num_heads: number of heads conv_repeats: number of conv blocks conv_kernel_size: kernel size for conv blocks conv_stride: stride for depth-wise conv in first conv block feed_forward: position-wise feed-forward string in ['linear', 'conv1dx2', 'conv1d-linear'] d_ff: dimension of middle layer in feed_forward block ff_kernel_size: kernel size for convolutional versions of ff block hid_act: ff and conv block hidden activation dropout_rate: dropout rate for ff and conv blocks att_context: maximum context range for local attention att_dropout_rate: dropout rate for attention block causal_pos_enc: if True, use causal positional encodings (when rel_pos_enc=True), it assumes that query q_i only attents to key k_j when j<=i conv_norm_layer: norm layer constructor for conv block, if None it uses BatchNorm se_r: Squeeze-Excitation compression ratio, if None it doesn't use Squeeze-Excitation ff_macaron: if True, it uses macaron-net style ff layers, otherwise transformer style. out_lnorm: if True, use LNorm layer at the output as in the conformer paper, we think that this layer is redundant and put it to False by default concat_after: if True, if concats attention input and output and apply linear transform, i.e., y = x + linear(concat(x, att(x))) if False, y = x + att(x) """
[docs] def __init__( self, num_feats, self_attn, num_heads, conv_repeats=1, conv_kernel_size=31, conv_stride=1, feed_forward="linear", d_ff=2048, ff_kernel_size=3, hid_act="swish", dropout_rate=0, att_context=25, att_dropout_rate=0, pos_enc_type="rel", causal_pos_enc=False, conv_norm_layer=None, se_r=None, ff_macaron=True, out_lnorm=False, concat_after=False, ): super().__init__() self.self_attn = self._make_att( self_attn, num_feats, num_heads, att_context, att_dropout_rate, pos_enc_type, causal_pos_enc, ) self.ff_scale = 1 self.ff_macaron = ff_macaron if ff_macaron: self.ff_scale = 0.5 self.feed_forward_macaron = self._make_ff( feed_forward, num_feats, d_ff, ff_kernel_size, hid_act, dropout_rate ) self.norm_ff_macaron = nn.LayerNorm(num_feats) self.feed_forward = self._make_ff( feed_forward, num_feats, d_ff, ff_kernel_size, hid_act, dropout_rate ) conv_blocks = [] for i in range(conv_repeats): block_i = ConformerConvBlock( num_feats, conv_kernel_size, conv_stride, activation=hid_act, norm_layer=conv_norm_layer, dropout_rate=dropout_rate, se_r=se_r, ) conv_stride = 1 conv_blocks.append(block_i) self.conv_blocks = nn.ModuleList(conv_blocks) self.norm_att = nn.LayerNorm(num_feats) self.norm_ff = nn.LayerNorm(num_feats) self.out_lnorm = out_lnorm if out_lnorm: self.norm_out = nn.LayerNorm(num_feats) self.dropout_rate = dropout_rate if self.dropout_rate > 0: self.dropout = nn.Dropout(self.dropout_rate) self.concat_after = concat_after if self.concat_after: self.concat_linear = nn.Linear(num_feats + num_feats, num_feats)
[docs] @staticmethod def _make_att( att_type, num_feats, num_heads, context, dropout_rate, pos_enc_type, causal_pos_enc, ): """Creates multihead attention block from att_type string Args: att_type: string in ['scaled-dot-prod-att-v1', 'local-scaled-dot-prod-att-v1'] num_feats: input/output feat. dimension (aka d_model) num_heads: number of heads dropout_rate: dropout rate for attention block pos_enc_type: type of positional encoder causal_pos_enc: if True, use causal positional encodings (when rel_pos_enc=True), it assumes that query q_i only attents to key k_j when j<=i Returns: Attention nn.Module """ assert num_feats % num_heads == 0 d_k = num_feats // num_heads if att_type == "scaled-dot-prod-v1": if pos_enc_type == "rel": return ScaledDotProdAttRelPosEncV1( num_feats, num_feats, num_heads, d_k, d_k, causal_pos_enc, dropout_rate, time_dim=1, ) return ScaledDotProdAttV1( num_feats, num_feats, num_heads, d_k, d_k, dropout_rate, time_dim=1 ) if att_type == "local-scaled-dot-prod-v1": if pos_enc_type == "rel": return LocalScaledDotProdAttRelPosEncV1( num_feats, num_feats, num_heads, d_k, d_k, context, causal_pos_enc, dropout_rate, time_dim=1, ) return LocalScaledDotProdAttV1( num_feats, num_feats, num_heads, d_k, d_k, context, dropout_rate, time_dim=1, )
[docs] @staticmethod def _make_ff(ff_type, num_feats, hid_feats, kernel_size, activation, dropout_rate): """Creates position-wise feed forward block from ff_type string Args: ff_type: string in ['linear', 'conv1dx2', 'conv1d-linear'] num_feats: input/output feat. dimension (aka d_model) hid_feats: dimension of middle layer in feed_forward block kernel_size: kernel size for convolutional versions of ff block dropout_rate: dropout rate for ff block activation: activation function for ff block Returns: Position-wise feed-forward nn.Module """ if ff_type == "linear": return PositionwiseFeedForward( num_feats, hid_feats, activation, dropout_rate, time_dim=1 ) if ff_type == "conv1dx2": return Conv1dx2( num_feats, hid_feats, kernel_size, activation, dropout_rate, time_dim=1 ) if ff_type == "conv1d-linear": return Conv1dLinear( num_feats, hid_feats, kernel_size, activation, dropout_rate, time_dim=1 )
[docs] def forward(self, x, pos_emb=None, mask=None): """Forward pass function Args: x: input tensor with size=(batch, time, num_feats) pos_emb: positional embedding size=(batch, time2, in_feats) as R_{L-1}, ..., R_0, when using relative postional encoder, otherwise None mask: mask to indicate valid time steps for x (batch, time) Returns: Tensor with output features Tensor with mask """ # macaron feed forward if self.ff_macaron: residual = x x = self.norm_ff_macaron(x) x = self.feed_forward_macaron(x) if self.dropout_rate > 0: x = self.dropout(x) x = residual + self.ff_scale * x # multihead attention residual = x x = self.norm_att(x) if pos_emb is None: x_att = self.self_attn(x, x, x, mask=mask) else: x_att = self.self_attn(x, x, x, pos_emb=pos_emb, mask=mask) if self.concat_after: x = torch.cat((x, x_att), dim=-1) x = self.concat_linear(x) else: x = x_att if self.dropout_rate > 0: x = self.dropout(x) x = residual + x # convolutional blocks x = x.transpose(1, 2) for block in range(len(self.conv_blocks)): x = self.conv_blocks[block](x) x = x.transpose(1, 2) # feed-forward block residual = x x = self.norm_ff(x) x = self.feed_forward(x) if self.dropout_rate > 0: x = self.dropout(x) x = residual + self.ff_scale * x # output norm if self.out_lnorm: x = self.norm_out(x) return x, mask