# Source code for hyperion.torch.layers.pos_encoder

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
import math

import torch
from torch import nn


class PosEncoder(nn.Module):
    """Sinusoidal positional encoding.

    The input is scaled by sqrt(num_feats) and a table of sine/cosine
    position embeddings is added to it. The table is built lazily and
    cached in ``self.pe``; it is rebuilt only when a longer input arrives.

    Attributes:
      num_feats: embedding dim
      dropout_rate: dropout rate, dropout is applied only when it is > 0
      xscale: factor applied to the input, sqrt(num_feats)
      pe: cached encoding table with shape=(1, time, num_feats), or None
        until the first call to ``_pe``
    """

    def __init__(self, num_feats, dropout_rate=0):
        super().__init__()
        self.num_feats = num_feats
        self.dropout_rate = dropout_rate
        self.xscale = math.sqrt(self.num_feats)
        # The dropout layer only exists when it is actually needed.
        if self.dropout_rate > 0:
            self.dropout = torch.nn.Dropout(p=dropout_rate)

        self.pe = None

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        s = "{}(num_feats={}, dropout_rate={})".format(
            self.__class__.__name__, self.num_feats, self.dropout_rate
        )
        return s

    def _pe(self, x, relative=False):
        """Returns the positional encoding table, rebuilding it if needed.

        The cached table is reused when it covers at least ``x.size(1)``
        positions; it is only moved to the dtype/device of ``x`` if they
        differ. Otherwise a new table of exactly ``x.size(1)`` positions
        is computed and cached.

        Args:
          x: Input with shape=(batch, time, C); only its time length,
            dtype and device are used.
          relative: if True, positions are ordered from time-1 down to 0
            (for relative positional encoders), otherwise from 0 to time-1.

        Returns:
          Tensor with shape=(1, T, num_feats) where T >= x.size(1).
        """
        num_frames = x.size(1)
        if self.pe is not None and self.pe.size(1) >= num_frames:
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return self.pe

        pe = torch.zeros(num_frames, self.num_feats)
        if relative:
            # this is for relative positional encoders
            position = torch.arange(
                num_frames - 1, -1, -1, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, num_frames, dtype=torch.float32).unsqueeze(1)

        div_term = torch.exp(
            torch.arange(0, self.num_feats, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.num_feats)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # When num_feats is odd there is one cosine column fewer than sine
        # columns, so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : self.num_feats // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)
        return self.pe

    def forward(self, x):
        """Add positional encoding.

        Args:
          x: Input with shape=(batch, time, C)

        Returns:
          x-scaled + pos-encoder, with dropout applied when
          dropout_rate > 0.
        """
        pe = self._pe(x)
        # The cached table may be longer than the input; keep the first
        # x.size(1) positions.
        x = x * self.xscale + pe[:, : x.size(1)]
        if self.dropout_rate > 0:
            return self.dropout(x)
        return x
class RelPosEncoder(PosEncoder):
    """Relative Positional encoding as defined in
    https://arxiv.org/pdf/1901.02860.pdf

    It returns the input and the positional encoder separately
    so they are mixed in the attention block later.

    Attributes:
      num_feats: embedding dim
      dropout_rate: dropout rate
    """

    def __init__(self, num_feats, dropout_rate=0):
        super().__init__(num_feats, dropout_rate)

    def forward(self, x):
        """Scale the input and return it with its positional encoding.

        Args:
          x: Input with shape=(batch, time, C)

        Returns:
          x-scaled, pos-encoding. The positional encoding has
          shape=(1, time, num_feats); dropout is applied to both outputs
          when dropout_rate > 0.
        """
        pe = self._pe(x, relative=True)
        x = x * self.xscale
        # we want embedding [R_L,..., R_0]
        # while in non relative we want [R_0, ..., R_L]
        # The table may be longer than the input, so keep its last
        # x.size(1) entries. An explicit start index is used because a
        # negative-length slice "[-0:]" would return the whole table for
        # an empty input.
        pos_emb = pe[:, pe.size(1) - x.size(1) :]
        # this pos_emb is matrix Q in
        # https://arxiv.org/pdf/1901.02860.pdf Appendix B
        # I think it should have been denoted as R,
        # probably a typo in the paper
        if self.dropout_rate > 0:
            x = self.dropout(x)
            pos_emb = self.dropout(pos_emb)

        return x, pos_emb
class NoPosEncoder(nn.Module):
    """Placeholder module used when positional encoding is disabled.

    It has no parameters and leaves its input untouched.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Identity map.

        Args:
          x: Input with shape=(batch, time, C)

        Returns:
          The same object x that was passed in.
        """
        return x