Source code for hyperion.torch.layers.attention

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import math

# import numpy
import torch
from torch import nn


class ScaledDotProdAttV1(nn.Module):
    """Scaled dot product multihead attention layer

    Attributes:
      in_feats: input feature dimension
      out_feats: output feature dimension
      num_heads: number of heads
      d_k: key/query projection dimension
      d_v: value projection dimension
      dropout_rate: dropout rate
      time_dim: time dimension in the input, default=1 meaning input
                dimensions are (batch, time, in_feats)
    """

    def __init__(
        self, in_feats, out_feats, num_heads, d_k, d_v, dropout_rate=0, time_dim=1
    ):
        super().__init__()
        # We assume d_v always equals d_k
        self.d_v = d_v
        self.d_k = d_k
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.time_dim = time_dim
        self.linear_q = nn.Linear(in_feats, num_heads * d_k)
        self.linear_k = nn.Linear(in_feats, num_heads * d_k)
        self.linear_v = nn.Linear(in_feats, num_heads * d_v)
        self.linear_out = nn.Linear(num_heads * d_v, out_feats)
        # attention weights of the last forward call
        self.attn = None
        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=dropout_rate)

    @property
    def in_feats(self):
        """Input feature dimension, read from the value projection."""
        return self.linear_v.in_features

    @property
    def out_feats(self):
        """Output feature dimension, read from the output projection."""
        return self.linear_out.out_features

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        s = "{}(in_feats={}, out_feats={}, num_heads={}, d_k={}, d_v={}, dropout_rate={}, time_dim={})".format(
            self.__class__.__name__,
            self.in_feats,
            self.out_feats,
            self.num_heads,
            self.d_k,
            self.d_v,
            self.dropout_rate,
            self.time_dim,
        )
        return s

    def _compute_qkv(self, query, key, value):
        """Projects query, key and value and splits them into heads.

        Args:
          query: query tensor, time axis at ``self.time_dim``
          key: key tensor, time axis at ``self.time_dim``
          value: value tensor, time axis at ``self.time_dim``

        Returns:
          q with size=(batch, head, time1, d_k)
          k with size=(batch, head, time2, d_k)
          v with size=(batch, head, time2, d_v)
        """
        batch_size = value.size(0)
        if self.time_dim != 1:
            # the linear layers act on the last axis, so time must be axis 1
            query = query.transpose(1, self.time_dim)
            key = key.transpose(1, self.time_dim)
            value = value.transpose(1, self.time_dim)

        q = self.linear_q(query).view(batch_size, -1, self.num_heads, self.d_k)
        k = self.linear_k(key).view(batch_size, -1, self.num_heads, self.d_k)
        v = self.linear_v(value).view(batch_size, -1, self.num_heads, self.d_v)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_v)
        return q, k, v

    def _compute_softmax(self, scores, mask):
        """Normalizes the attention scores over the time2 axis.

        Args:
          scores: scores with size=(batch, head, time1, time2)
          mask: None, or a tensor where zeros mark the positions to ignore,
                with size=(batch, time1, time2) or size=(batch, time).

        Returns:
          Attention weights with size=(batch, head, time1, time2)
        """
        if mask is None:
            return torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        # True where attention has to be removed
        mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2) or (batch, 1, time)
        if scores.dtype == torch.half:
            # most negative finite float16 value
            min_value = -65504
        else:
            min_value = -1e20

        if mask.dim() == 4:
            scores = scores.masked_fill(mask, min_value)
            # masked positions are also forced to exactly 0 after the softmax
            return torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)

        # per-frame mask: remove masked frames both as keys and as queries
        mask1 = mask.unsqueeze(2)  # (batch, 1, 1, time)
        mask2 = mask.unsqueeze(-1)  # (batch, 1, time, 1)
        scores = scores.masked_fill(mask1, min_value)
        scores = scores.masked_fill(mask2, min_value)
        return torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

    def _apply_attn(self, v):
        """Applies the attention weights in self.attn to the values.

        Args:
          v: values with size=(batch, head, time2, d_v)

        Returns:
          Tensor with size=(batch, time1, out_feats)
        """
        batch_size = v.size(0)
        if self.dropout_rate > 0:
            p_attn = self.dropout(self.attn)
        else:
            p_attn = self.attn

        x = torch.matmul(p_attn, v)  # (batch, head, time1, d_v)
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_v)
        )  # (batch, time1, head * d_v)
        return self.linear_out(x)  # (batch, time1, out_feats)

    # Name-mangled aliases (three leading underscores): forward() below always
    # reaches the implementations of this class, even when a subclass
    # overrides _compute_softmax/_apply_attn with another signature.
    ___compute_softmax = _compute_softmax
    ___apply_attn = _apply_attn

    def forward(self, query, key, value, mask=None):
        """Computes 'Scaled Dot Product Attention'.

        Args:
          query: query with size=(batch, time1, in_feats),
                 where time1 is the output time dimension
          key: key with size=(batch, time2, in_feats)
               where time2 is the input time dimension
          value: value with size=(batch, time2, in_feats)
          mask: optional mask with size=(batch, time1, time2),
                to zero attention between some time steps
                or size=(batch, time) to make time1=time2
        Returns:
          Attention weighted average of the value with size=(batch, time1, out_feats)
        """
        q, k, v = self._compute_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)
        self.attn = self.___compute_softmax(scores, mask)
        return self.___apply_attn(v)
class LocalScaledDotProdAttV1(ScaledDotProdAttV1):
    """Local Scaled dot product multihead attention layer
       It calculates self-attention between time steps within
       a window of 'context' frames.

    The attention map is approximated by two block-diagonal maps: one with
    blocks aligned to the start of the sequence and one with the blocks
    shifted by half a block. Both maps share a single softmax.

    Attributes:
      in_feats: input feature dimension
      out_feats: output feature dimension
      num_heads: number of heads
      d_k: key/query projection dimension
      d_v: value projection dimension
      context: maximum attention temporal context.
      dropout_rate: dropout rate
      time_dim: time dimension in the input, default=1 meaning input
                dimensions are (batch, time, in_feats)
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        d_k,
        d_v,
        context=25,
        dropout_rate=0,
        time_dim=1,
    ):
        """Construct an MultiHeadedAttention object."""
        super().__init__(
            in_feats, out_feats, num_heads, d_k, d_v, dropout_rate, time_dim
        )
        self.context = context

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        s = (
            "{}(in_feats={}, out_feats={}, num_heads={}, d_k={}, d_v={}, "
            "context={}, dropout_rate={}, time_dim={})".format(
                self.__class__.__name__,
                self.in_feats,
                self.out_feats,
                self.num_heads,
                self.d_k,
                self.d_v,
                self.context,
                self.dropout_rate,
                self.time_dim,
            )
        )
        return s

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _min_value(scores):
        """Returns the value used as -inf for the dtype of scores."""
        if scores.dtype == torch.half:
            return -65504
        return -1e20

    def _time_to_dim1(self, query, key, value):
        """Moves the time axis of query, key and value to axis 1."""
        if self.time_dim != 1:
            query = query.transpose(1, self.time_dim)
            key = key.transpose(1, self.time_dim)
            value = value.transpose(1, self.time_dim)
        return query, key, value

    def _pad_and_project(self, query, key, value, pad1, pad2):
        """Zero-pads the end of the time axis and applies the q/k/v projections."""
        if pad1 > 0:
            query = nn.functional.pad(query, (0, 0, 0, pad1))

        if pad2 > 0:
            key = nn.functional.pad(key, (0, 0, 0, pad2))
            value = nn.functional.pad(value, (0, 0, 0, pad2))

        q0 = self.linear_q(query)  # (batch, time1, head*d_k)
        k0 = self.linear_k(key)  # (batch, time2, head*d_k)
        v0 = self.linear_v(value)  # (batch, time2, head*d_v)
        return q0, k0, v0

    def _split_blocks(self, x0, num_blocks, d):
        """Reshapes (batch, time, head*d) into (batch, head, blocks, time/blocks, d)."""
        batch_size = x0.size(0)
        return (
            x0.view(batch_size, -1, self.num_heads, d)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, self.num_heads, num_blocks, -1, d)
        )

    def _merge_blocks(self, x):
        """Reshapes (batch, head, blocks, time, d_v) into (batch, blocks*time, head*d_v)."""
        batch_size = x.size(0)
        return (
            x.view(batch_size, self.num_heads, -1, self.d_v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_v)
        )

    def _block_scores(self, q0, k0, num_blocks):
        """Computes scaled dot-product scores inside each of num_blocks blocks."""
        q = self._split_blocks(q0, num_blocks, self.d_k)
        # (batch, head, blocks, time1, d_k)
        k = self._split_blocks(k0, num_blocks, self.d_k)
        # (batch, head, blocks, time2, d_k)
        return torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

    # ------------------------------------------------------------------
    # projections
    # ------------------------------------------------------------------
    def _compute_qkv00(self, query, key, value):
        """Projects q/k/v using key blocks of exactly 'context' frames.

        The number of blocks is ceil(t2 / context); the query block length is
        derived from it.

        Returns:
          q0 with size=(batch, time1 + pad, head*d_k)
          k0 with size=(batch, time2 + pad, head*d_k)
          v0 with size=(batch, time2 + pad, head*d_v)
          context_q: query frames per block
          context_k: key frames per block
          num_blocks: number of blocks
        """
        t1 = query.size(self.time_dim)
        t2 = key.size(self.time_dim)
        query, key, value = self._time_to_dim1(query, key, value)

        context_k = self.context
        num_blocks = math.ceil(t2 / context_k)
        context_q = math.ceil(t1 / num_blocks)
        num_blocks_q = math.ceil(t1 / context_q)
        assert (
            num_blocks == num_blocks_q
        ), "num_blocks_k({})!=num_blocks_q({}), context_k={}, context_q={}, t1={}, t2={}".format(
            num_blocks, num_blocks_q, context_k, context_q, t1, t2
        )
        pad1 = context_q * num_blocks - t1
        pad2 = context_k * num_blocks - t2
        q0, k0, v0 = self._pad_and_project(query, key, value, pad1, pad2)
        return q0, k0, v0, context_q, context_k, num_blocks

    def _compute_qkv0(self, query, key, value):
        """Projects q/k/v using round(t2 / context) blocks.

        Both block lengths are derived from the number of blocks, so the key
        block length can differ from 'context'.

        Returns:
          q0 with size=(batch, time1 + pad, head*d_k)
          k0 with size=(batch, time2 + pad, head*d_k)
          v0 with size=(batch, time2 + pad, head*d_v)
          context_q: query frames per block
          context_k: key frames per block
          num_blocks: number of blocks
        """
        t1 = query.size(self.time_dim)
        t2 = key.size(self.time_dim)
        query, key, value = self._time_to_dim1(query, key, value)

        num_blocks = round(t2 / self.context)
        context_k = math.ceil(t2 / num_blocks)
        context_q = math.ceil(t1 / num_blocks)
        pad1 = context_q * num_blocks - t1
        pad2 = context_k * num_blocks - t2
        q0, k0, v0 = self._pad_and_project(query, key, value, pad1, pad2)
        return q0, k0, v0, context_q, context_k, num_blocks

    # ------------------------------------------------------------------
    # scores
    # ------------------------------------------------------------------
    def _compute_scores(
        self, q0, k0, num_blocks, context_q, context_k, q_left_shift, k_left_shift
    ):
        """Computes a block-diagonal score matrix.

        Args:
          q0: projected queries with size=(batch, time1, head*d_k)
          k0: projected keys with size=(batch, time2, head*d_k)
          num_blocks: number of blocks of the output
          context_q: query frames per block
          context_k: key frames per block
          q_left_shift: frames removed at the start of q0; when > 0 the
                        shifted block-diagonal matrix is computed
          k_left_shift: frames removed at the start of k0

        Returns:
          Scores with size=(batch, head, num_blocks, context_q, context_k)
        """
        if q_left_shift > 0:
            # we are computing the shifted block-diag score matrix
            q_right_shift = context_q - q_left_shift
            k_right_shift = context_k - k_left_shift
            q0 = q0[:, q_left_shift:-q_right_shift]
            k0 = k0[:, k_left_shift:-k_right_shift]

        return self._block_scores(q0, k0, num_blocks)

    @staticmethod
    def _softmax(scores1, scores2, shift1, shift2, t1, t2):
        """Computes softmax for block diagonal attention maps

        scores1 and scores2 are modified in place.

        Args:
          scores1: attention scores from block-diagonal score matrix
                   with size=(batch, heads, blocks, t1, t2)
          scores2: attention scores from a shifted block-diagonal score matrix
                   with size=(batch, heads, blocks-1, t1, t2)
          shift1: shift of diagonal blocks of scores2 wrt scores1 in time steps in the
                  time dimension 1
          shift2: shift of diagonal blocks of scores2 wrt scores1 in time steps in the
                  time dimension 2, with self-attention shift1=shift2
          t1: length of time dimension 1 (output time dimension)
          t2: length of time dimension 2 (input time dimension), with self-att t1=t2.

        Returns
          probs1: posterior attention scores for block-diagonal att. matrix
                  with size=(batch, heads, blocks, t1, t2)
          probs2: posterior attention scores for a shifted block-diagonal att. matrix
                  with size=(batch, heads, blocks-1, t1, t2)
        """
        if scores2.dtype == torch.half:
            min_val = -65504
        else:
            min_val = -1e20

        batch_size = scores1.size(0)
        num_heads = scores1.size(1)
        num_blocks = scores1.size(2)
        context1 = scores1.size(3)
        context2 = scores1.size(4)

        # set elements in scores2 that overlap with elements in scores1 to -inf
        scores2[:, :, :, : context1 - shift1, : context2 - shift2] = min_val
        # NOTE(review): for odd block sizes the overlap with the next block of
        # scores1 appears to start at (context1 - shift1, context2 - shift2),
        # not (shift1, shift2), so this may mask a few valid pairs. Left
        # unchanged because fixing it would change outputs; confirm.
        scores2[:, :, :, shift1:, shift2:] = min_val

        # set the padding time steps that we had to add to make integer
        # block-number to -inf
        # in scores1
        dt1 = max(0, scores1.size(2) * scores1.size(3) - t1)
        if dt1 > 0:
            scores1[:, :, -1, -dt1:, :] = min_val

        # in scores2
        dt1 = max(0, scores2.size(2) * scores2.size(3) + shift1 - t1)
        if dt1 > 0:
            scores2[:, :, -1, -dt1:, :] = min_val

        dt2 = max(0, scores1.size(2) * scores1.size(4) - t2)
        if dt2 > 0:
            scores1[:, :, -1, :, -dt2:] = min_val

        # in scores2
        dt2 = max(0, scores2.size(2) * scores2.size(4) + shift2 - t2)
        if dt2 > 0:
            scores2[:, :, -1, :, -dt2:] = min_val

        # flatten blocks and time1 dimensions
        scores1 = scores1.view(batch_size, num_heads, -1, context2)
        scores2 = scores2.view(batch_size, num_heads, -1, context2)
        # pad scores2 to have the same size as scores1
        scores2 = nn.functional.pad(
            scores2, (0, 0, shift1, context1 - shift1), mode="constant", value=min_val
        )
        # concat scores1, scores2 and do softmax in time2 dimension
        # (batch, heads, blocks*time1, 2*time2)
        probs = torch.softmax(torch.cat((scores1, scores2), dim=-1), dim=-1)

        # now we separate back probs into probs1, and probs2
        probs1 = (
            probs[:, :, :, :context2]
            .contiguous()
            .view(batch_size, num_heads, num_blocks, -1, context2)
        )
        probs2 = (
            probs[:, :, shift1 : -(context1 - shift1), context2:]
            .contiguous()
            .view(batch_size, num_heads, num_blocks - 1, -1, context2)
        )
        return probs1, probs2

    # ------------------------------------------------------------------
    # masking
    # ------------------------------------------------------------------
    def _mask_scores_1d(self, scores, mask, shift1, shift2):
        """Masks block-diagonal scores with a per-frame mask (time1 = time2).

        A (query, key) pair is masked when either frame is masked.

        Args:
          scores: scores with size=(batch, head, blocks, context1, context2)
          mask: boolean tensor with size=(batch, time), True marks the
                frames to remove
          shift1: first query frame covered by block 0
          shift2: first key frame covered by block 0

        Returns:
          Masked copy of scores.
        """
        batch_size = scores.size(0)
        num_blocks = scores.size(2)
        context1 = scores.size(3)
        context2 = scores.size(4)

        # everything starts masked, so frames beyond the mask length (padding)
        # stay masked; the head axis has size 1 and is broadcast
        mask_blocks = torch.ones(
            (batch_size, 1, num_blocks, context1, context2),
            dtype=torch.bool,
            device=scores.device,
        )
        num_frames = mask.size(1)
        t1_start = shift1
        t2_start = shift2
        for block in range(num_blocks):
            t1_end = min(t1_start + context1, num_frames)
            t2_end = min(t2_start + context2, num_frames)
            if t1_end > t1_start and t2_end > t2_start:
                mask_q = mask[:, t1_start:t1_end].unsqueeze(-1)  # (batch, n1, 1)
                mask_k = mask[:, t2_start:t2_end].unsqueeze(1)  # (batch, 1, n2)
                mask_blocks[
                    :, 0, block, : (t1_end - t1_start), : (t2_end - t2_start)
                ] = (mask_q | mask_k)

            t1_start += context1
            t2_start += context2

        return scores.masked_fill(mask_blocks, self._min_value(scores))

    def _mask_scores_2d(self, scores, mask, shift1, shift2):
        """Masks block-diagonal scores with a (time1, time2) mask.

        Args:
          scores: scores with size=(batch, head, blocks, context1, context2)
          mask: boolean tensor with size=(batch, time1, time2), True marks
                the positions to remove
          shift1: first query frame covered by block 0
          shift2: first key frame covered by block 0

        Returns:
          Masked copy of scores.
        """
        batch_size = scores.size(0)
        num_blocks = scores.size(2)
        context1 = scores.size(3)
        context2 = scores.size(4)

        # everything starts masked, so frames beyond the mask size (padding)
        # stay masked; the head axis has size 1 and is broadcast
        mask_blocks = torch.ones(
            (batch_size, 1, num_blocks, context1, context2),
            dtype=torch.bool,
            device=scores.device,
        )
        t1_start = shift1
        t2_start = shift2
        for block in range(num_blocks):
            t1_end = min(t1_start + context1, mask.size(1))
            t2_end = min(t2_start + context2, mask.size(2))
            if t1_end > t1_start and t2_end > t2_start:
                mask_blocks[
                    :, 0, block, : (t1_end - t1_start), : (t2_end - t2_start)
                ] = mask[:, t1_start:t1_end, t2_start:t2_end]

            t1_start += context1
            t2_start += context2

        return scores.masked_fill(mask_blocks, self._min_value(scores))

    def _compute_softmax(
        self, scores1, scores2, mask, q_left_shift, k_left_shift, t1, t2
    ):
        """Masks both score matrices and stores the attention weights.

        The result is written to self.attn1 (block-diagonal weights) and
        self.attn2 (shifted block-diagonal weights).

        Args:
          scores1: block-diagonal scores
                   with size=(batch, head, blocks, context_q, context_k)
          scores2: shifted block-diagonal scores
                   with size=(batch, head, blocks-1, context_q, context_k)
          mask: None, or a tensor where zeros mark positions to ignore, with
                size=(batch, time1, time2), (batch, 1, time2),
                (batch, 1, time1, time2) or (batch, time) when time1=time2
          q_left_shift: shift of the blocks of scores2 in the query axis
          k_left_shift: shift of the blocks of scores2 in the key axis
          t1: number of query frames before padding
          t2: number of key frames before padding
        """
        if mask is not None:
            # put to -inf scores in points where mask==0
            if mask.dim() == 4:
                mask = mask.squeeze(1)  # (batch, time1, time2)

            mask = mask.eq(0)
            if mask.dim() == 3:
                # case when mask is 2d matrix per batch element
                if mask.size(1) == 1:
                    # key-only mask, shared by all the query frames
                    mask = mask.expand(-1, t1, -1)

                # first, we mask block diagonal blocks
                scores1 = self._mask_scores_2d(scores1, mask, 0, 0)
                # second, we mask shifted block diagonal blocks
                scores2 = self._mask_scores_2d(
                    scores2, mask, q_left_shift, k_left_shift
                )
            else:
                # case when mask is 1d vector per batch element,
                # meaning that time1 and time2 are the same, so mask is symmetric
                # first, we mask block diagonal blocks
                scores1 = self._mask_scores_1d(scores1, mask, 0, 0)
                # second, we mask shifted block diagonal blocks
                scores2 = self._mask_scores_1d(
                    scores2, mask, q_left_shift, k_left_shift
                )

        self.attn1, self.attn2 = self._softmax(
            scores1, scores2, q_left_shift, k_left_shift, t1, t2
        )

    # ------------------------------------------------------------------
    # output
    # ------------------------------------------------------------------
    def _apply_attn(self, v0, t1):
        """Applies self.attn1 and self.attn2 to the projected values.

        Args:
          v0: projected values with size=(batch, time2 + pad, head*d_v)
          t1: number of query frames before padding

        Returns:
          Tensor with size=(batch, t1, out_feats)
        """
        if self.dropout_rate > 0:
            p_attn1 = self.dropout(self.attn1)
            p_attn2 = self.dropout(self.attn2)
        else:
            p_attn1 = self.attn1
            p_attn2 = self.attn2

        num_blocks = p_attn1.size(2)
        context_q = p_attn1.size(3)
        context_k = p_attn1.size(4)
        q_left_shift = context_q // 2
        k_left_shift = context_k // 2
        q_right_shift = context_q - q_left_shift
        k_right_shift = context_k - k_left_shift

        # values are reshaped with d_v, so d_v may differ from d_k
        v = self._split_blocks(v0, num_blocks, self.d_v)
        # (batch, head, blocks, time1, time2) x (batch, head, blocks, time2, d_v)
        x = self._merge_blocks(torch.matmul(p_attn1, v))  # (batch, time1, head*d_v)

        v = self._split_blocks(
            v0[:, k_left_shift:-k_right_shift], num_blocks - 1, self.d_v
        )
        # (batch, head, blocks-1, time1, time2) x (batch, head, blocks-1, time2, d_v)
        x2 = self._merge_blocks(torch.matmul(p_attn2, v))

        # the shifted blocks cover the frames between the two half blocks
        # at both ends of the sequence
        x[:, q_left_shift:-q_right_shift:] = x[:, q_left_shift:-q_right_shift:] + x2
        # remove the padding frames
        x = x[:, :t1]
        return self.linear_out(x)  # (batch, time1, out_feats)

    def _forward_local(self, query, key, value, mask):
        """Runs local attention with round(t2 / context) blocks."""
        t1 = query.size(self.time_dim)
        t2 = key.size(self.time_dim)
        q0, k0, v0, context_q, context_k, num_blocks = self._compute_qkv0(
            query, key, value
        )
        # q0 size=(batch, time1, head*d_k)
        # k0 size=(batch, time2, head*d_k)
        # v0 size=(batch, time2, head*d_v)

        # compute block diagonal affinity matrix
        scores1 = self._compute_scores(q0, k0, num_blocks, context_q, context_k, 0, 0)
        # (batch, head, blocks, context_q, context_k)

        # compute shifted block diagonal affinity matrix
        q_left_shift = context_q // 2
        k_left_shift = context_k // 2
        scores2 = self._compute_scores(
            q0, k0, num_blocks - 1, context_q, context_k, q_left_shift, k_left_shift
        )
        # (batch, head, blocks-1, context_q, context_k)

        # combine both block diagonal affinity matrix to do the softmax
        self._compute_softmax(
            scores1, scores2, mask, q_left_shift, k_left_shift, t1, t2
        )
        return self._apply_attn(v0, t1)

    def forward1(self, query, key, value, mask=None):
        """Computes 'Local Scaled Dot Product Attention'.

        Same as forward, but it only falls back to full attention when the
        sequence is too short to form two blocks.

        Args:
          query: query with size=(batch, time1, in_feats),
                 where time1 is the output time dimension
          key: key with size=(batch, time2, in_feats)
               where time2 is the input time dimension
          value: value with size=(batch, time2, in_feats)
          mask: optional mask with size=(batch, time1, time2),
                to zero attention between some time steps.
                or (batch, time) if time1=time2
        Returns:
          Attention weighted average of the values with size=(batch, time1, out_feats)
        """
        t2 = key.size(self.time_dim)
        # the shifted pass needs at least two blocks
        if t2 <= self.context or round(t2 / self.context) < 2:
            return super().forward(query, key, value, mask)

        return self._forward_local(query, key, value, mask)

    def forward2(self, query, key, value, mask=None):
        """Computes 'Local Scaled Dot Product Attention'.

        Variant that uses key blocks of exactly 'context' frames.
        The mask is only used when falling back to full attention
        (t2 <= context); otherwise it is ignored.

        Args:
          query: query with size=(batch, time1, in_feats),
                 where time1 is the output time dimension
          key: key with size=(batch, time2, in_feats)
               where time2 is the input time dimension
          value: value with size=(batch, time2, in_feats)
          mask: optional mask with size=(batch, time1, time2),
                to zero attention between some time steps.
                or (batch, time) if time1=time2
        Returns:
          Attention weighted average of the values with size=(batch, time1, out_feats)
        """
        t1 = query.size(self.time_dim)
        t2 = key.size(self.time_dim)
        if t2 <= self.context:
            return super().forward(query, key, value, mask)

        q0, k0, v0, context_q, context_k, num_blocks = self._compute_qkv00(
            query, key, value
        )

        # compute block diagonal affinity matrix
        scores1 = self._block_scores(q0, k0, num_blocks)
        # (batch, head, blocks, context_q, context_k)

        # compute shifted block diagonal affinity matrix
        q_left_shift = context_q // 2
        k_left_shift = context_k // 2
        q_right_shift = context_q - q_left_shift
        k_right_shift = context_k - k_left_shift
        scores2 = self._block_scores(
            q0[:, q_left_shift:-q_right_shift],
            k0[:, k_left_shift:-k_right_shift],
            num_blocks - 1,
        )
        # (batch, head, blocks-1, context_q, context_k)

        # combine both block diagonal affinity matrix to do the softmax
        self.attn1, self.attn2 = self._softmax(
            scores1, scores2, q_left_shift, k_left_shift, t1, t2
        )
        return self._apply_attn(v0, t1)

    def forward(self, query, key, value, mask=None):
        """Computes 'Local Scaled Dot Product Attention'.

        Sequences with time2 <= 2 * context use full attention.

        Args:
          query: query with size=(batch, time1, in_feats),
                 where time1 is the output time dimension
          key: key with size=(batch, time2, in_feats)
               where time2 is the input time dimension
          value: value with size=(batch, time2, in_feats)
          mask: optional mask with size=(batch, time1, time2),
                to zero attention between some time steps.
                or (batch, time) if time1=time2
        Returns:
          Attention weighted average of the values with size=(batch, time1, out_feats)
        """
        t2 = key.size(self.time_dim)
        if t2 <= 2 * self.context:
            return super().forward(query, key, value, mask)

        return self._forward_local(query, key, value, mask)
[docs]class ScaledDotProdAttRelPosEncV1(ScaledDotProdAttV1): """Scaled dot product multihead attention layer with relative positional encoders as defined in https://arxiv.org/pdf/1901.02860.pdf Attributes: in_feats: input feature dimension out_feats: output feature dimension num_heads: number of heads d_k: key/query projection dimension d_v: value projection dimension causal_pos_enc: positional encoder is 0 for attending future frames. dropout_rate: dropout rate time_dim: time dimension in the input, default=1 meaning input dimensions are (batch, time, in_feats) """
[docs] def __init__( self, in_feats, out_feats, num_heads, d_k, d_v, causal_pos_enc=False, dropout_rate=0, time_dim=1, ): super().__init__( in_feats, out_feats, num_heads, d_k, d_v, dropout_rate=dropout_rate, time_dim=time_dim, ) self.linear_pos = nn.Linear(in_feats, num_heads * d_k) # u, v in paper, Sec 3.3, 2nd eq. self.u = nn.Parameter(torch.Tensor(num_heads, d_k)) self.v = nn.Parameter(torch.Tensor(num_heads, d_k)) # we use same init as in espnet nn.init.xavier_uniform_(self.u) nn.init.xavier_uniform_(self.v) self.causal_pos_enc = causal_pos_enc self._tril = None self._tril_diag = 0 self._triu = None self._triu_diag = 0
[docs] def _apply_tril(self, x): """Applies lower triangular mask to (Q + v^T) W R_{i-j} attention matrix to keep causal attention points, i.e., i-j >= 0 E.g., if t1=3, t2=4 this will apply a mask [1 1 0 0; 1 1 1 0; 1 1 1 1 ] """ diag = x.size(3) - x.size(2) if ( self._tril is None or self._tril.size(2) < x.size(2) or self._tril.size(3) < x.size(3) or self._tril_diag != diag ): # in these cases we need to recompute the lower triangular mask ones = torch.ones((x.size(2), x.size(3)), dtype=x.dtype, device=x.device) self._tril = torch.tril(ones, diag)[None, None, :, :] self._tril_diag = diag tril = self._tril else: tril = self._tril[:, :, : x.size(2), : x.size(3)] return x * tril
[docs] def _apply_triu(self, x): """Applies upper triangular mask to (Q + v^T) W R_{i-j} attention matrix to keep non-causal attention points, i.e., i-j < 0 E.g., if t1=3, t2=4 this will apply a mask [0 0 1 1; 0 0 0 1; 0 0 0 0 ] """ # we add 1 to put the diagonal to 0 so we don't count the R_0 embedding twice diag = x.size(3) - x.size(2) + 1 if ( self._triu is None or self._triu.size(2) < x.size(2) or self._triu.size(3) < x.size(3) or self._triu_diag != diag ): # in these cases we need to recompute the lower triangular mask ones = torch.ones((x.size(2), x.size(3)), dtype=x.dtype, device=x.device) self._triu = torch.triu(ones, diag)[None, None, :, :] self._triu_diag = diag triu = self._triu else: triu = self._triu[:, :, -x.size(2) :, -x.size(3) :] return x * triu
[docs] def _left_shift(self, x): """Applies left shifts to the rows of x to get scores with relative pos encodings R_{i-j} i-j >=0, causal attention E.g. [q0 R3, q0 R2, q0 R1, q0 R0; q1 R3, q1 R2, q1 R1, q1 R0; q2 R3, q2 R2, q2 R1, q2 R0] becomes: [q0 R1, q0 R0, 0 , 0 ; q1 R2, q1 R1, q1 R0, 0 ; q2 R3, q2 R2, q2 R1, q2 R0] """ x_pad = nn.functional.pad(x, (1, 0), mode="constant", value=0) x_pad = x_pad.view(*x.size()[:2], x.size(3) + 1, x.size(2)) x = x_pad[:, :, 1:].view_as(x) return self._apply_tril(x)
[docs] def _right_shift(self, x): """Applies right shifts to the rows of x to get scores with relative pos encodings R_{i-j} i-j < 0, non-causal attention E.g. [q0 R_0, q0 R_{-1}, q0 R_{-2}; q1 R_0, q1 R_{-1}, q1 R_{-2}; q2 R_0, q1 R_{-1}, q2 R_{-2}] becomes: [ 0, q0 R_{-1}, q0 R_{-2}; 0, 0 , q1 R_{-1}; 0, 0 , 0 ] """ x_pad = nn.functional.pad(x, (0, 1), mode="constant", value=0) x_pad = x_pad.view(*x.size()[:2], x.size(3) + 1, x.size(2)) x = x_pad[:, :, :-1].view_as(x) return self._apply_triu(x)
def forward(self, query, key, value, pos_emb=None, mask=None):
    """Computes 'Scaled Dot Product Attention' with relative positional encodings.

    Args:
      query: query with size=(batch, time1, in_feats),
             where time1 is the output time dimension
      key: key with size=(batch, time2, in_feats)
             where time2 is the input time dimension
      value: value with size=(batch, time2, in_feats)
      pos_emb: positional embedding size=(batch, time2, in_feats) as R_{L-1}, ..., R_0.
             It is required, the None default only keeps the signature
             compatible with the other attention layers.
      mask: optional mask with size=(batch, time1, time2),
             to zero attention between some time steps
             or size=(batch, time) to make time1=time2

    Returns:
      Attention weighted average of the value with size=(batch, time1, out_feats)

    Raises:
      ValueError: if pos_emb is None.
    """
    if pos_emb is None:
        raise ValueError(
            "pos_emb must be provided to compute relative positional attention"
        )

    q, k, v = self._compute_qkv(query, key, value)
    # q: (batch, head, time1, d_k), k: (batch, head, time2, d_k)

    pos_batch_size = pos_emb.size(0)
    p = self.linear_pos(pos_emb).view(pos_batch_size, -1, self.num_heads, self.d_k)
    p = p.transpose(1, 2)  # (batch, head, time2, d_k)

    # put head next to d_k so that u and v (head, d_k) broadcast over batch and time
    q = q.transpose(1, 2)  # (batch, time1, head, d_k)
    q_plus_u = (q + self.u).transpose(1, 2)  # (batch, head, time1, d_k)
    q_plus_v = (q + self.v).transpose(1, 2)  # (batch, head, time1, d_k)

    # compute A(a) + A(c) in Sec3.3, 2nd Eq.
    AC = torch.matmul(q_plus_u, k.transpose(-2, -1))  # (batch, head, time1, time2)

    # compute A(b) + A(d) in Sec3.3, 2nd Eq. for the causal part
    # This is the sum of Btilde and Dtilde in the Appendix of the paper
    BDtilde = torch.matmul(
        q_plus_v, p.transpose(-2, -1)
    )  # (batch, head, time1, time2)
    # apply left shift as indicated in the Appendix to get B+D
    BD = self._left_shift(BDtilde)

    if not self.causal_pos_enc:
        # compute A(b) + A(d) for the non-causal part,
        # this is not included in the paper because it doesn't allow to attend
        # to future positions
        # we assume that t2 >= t1
        # dt is taken from the projected tensors, whose time axis is fixed,
        # so it is also right when the inputs use time_dim != 1
        dt = k.size(2) - q_plus_v.size(2)
        # flip returns a copy, so the in-place sign change below does not
        # modify the caller's pos_emb
        pos_emb_noncausal = pos_emb[:, dt:].flip(
            dims=(1,)
        )  # we flip to get R_0, ..., R_{L-1}
        pos_emb_noncausal[
            :, :, 0::2
        ] *= -1  # we multiply sin emb by -1 to get R_0, R_{-1}, ..., R_{-(L-1)}
        p = self.linear_pos(pos_emb_noncausal).view(
            pos_batch_size, -1, self.num_heads, self.d_k
        )
        p = p.transpose(1, 2)  # (batch, head, time2-dt, d_k)
        BDtilde = torch.matmul(
            q_plus_v, p.transpose(-2, -1)
        )  # (batch, head, time1, time2-dt)
        BD_noncausal = self._right_shift(BDtilde)
        BD[:, :, :, dt:] += BD_noncausal

    # add and normalize
    scores = (AC + BD) / math.sqrt(self.d_k)  # (batch, head, time1, time2)
    self.attn = self._compute_softmax(scores, mask)
    return self._apply_attn(v)
class LocalScaledDotProdAttRelPosEncV1(LocalScaledDotProdAttV1):
    """Local Scaled dot product multihead attention layer
       It calculates self-attention between time steps within
       a window of 'context' frames. It uses relative positional encoders as defined in
       https://arxiv.org/pdf/1901.02860.pdf

    Attributes:
      in_feats: input feature dimension
      out_feats: output feature dimension
      num_heads: number of heads
      d_k: key/query projection dimension
      d_v: value projection dimension
      context: maximum attention temporal context.
      causal_pos_enc: positional encoder is 0 for attending future frames.
      dropout_rate: dropout rate
      time_dim: time dimension in the input, default=1 meaning input
                dimensions are (batch, time, in_feats)
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        d_k,
        d_v,
        context=25,
        causal_pos_enc=False,
        dropout_rate=0,
        time_dim=1,
    ):
        """Creates the extra projection and biases used by relative positional attention.

        All arguments except causal_pos_enc are forwarded to the parent
        constructor; see the class docstring for their meaning.
        """
        super().__init__(
            in_feats,
            out_feats,
            num_heads,
            d_k,
            d_v,
            context,
            dropout_rate=dropout_rate,
            time_dim=time_dim,
        )

        # projection applied to the positional embeddings
        self.linear_pos = nn.Linear(in_feats, num_heads * d_k)

        # u, v in paper, Sec 3.3, 2nd eq.
        self.u = nn.Parameter(torch.Tensor(num_heads, d_k))
        self.v = nn.Parameter(torch.Tensor(num_heads, d_k))
        # we use same init as in espnet
        nn.init.xavier_uniform_(self.u)
        nn.init.xavier_uniform_(self.v)

        self.causal_pos_enc = causal_pos_enc

        # lazily built triangular masks and the diagonal offsets they were built with
        self._tril = None
        self._tril_diag = 0
        self._triu = None
        self._triu_diag = 0

    def _apply_tril(self, x):
        """Applies lower triangular mask to (Q + v^T) W R_{i-j} attention matrix
           to keep causal attention points, i.e., i-j >= 0

        E.g., if t1=3, t2=4 this will apply a mask
        [1 1 0 0;
         1 1 1 0;
         1 1 1 1 ]

        The mask is cached in self._tril and reused (sliced if needed) while
        the cached tensor is large enough and has the same diagonal offset,
        dtype and device as x; otherwise it is rebuilt.

        Args:
          x: 5-D tensor whose last two dimensions are (t1, t2).

        Returns:
          Tensor with the same shape as x, zeroed above the shifted diagonal.
        """
        t1, t2 = x.size(3), x.size(4)
        diag = t2 - t1
        cached = self._tril
        if (
            cached is None
            or cached.size(3) < t1
            or cached.size(4) < t2
            or self._tril_diag != diag
            # a mask built for another dtype/device must not be reused: it
            # would silently promote the result dtype or fail with a device
            # mismatch
            or cached.dtype != x.dtype
            or cached.device != x.device
        ):
            # in these cases we need to recompute the lower triangular mask
            ones = torch.ones((t1, t2), dtype=x.dtype, device=x.device)
            self._tril = torch.tril(ones, diag)[None, None, None, :, :]
            self._tril_diag = diag
            tril = self._tril
        else:
            # same diagonal offset: the top-left corner of the cached mask is
            # the mask for the smaller matrix
            tril = cached[:, :, :, :t1, :t2]

        return x * tril

    def _apply_triu(self, x):
        """Applies upper triangular mask to (Q + v^T) W R_{i-j} attention matrix
           to keep non-causal attention points, i.e., i-j < 0

        E.g., if t1=3, t2=4 this will apply a mask
        [0 0 1 1;
         0 0 0 1;
         0 0 0 0 ]

        The mask is cached in self._triu and reused (sliced if needed) while
        the cached tensor is large enough and has the same diagonal offset,
        dtype and device as x; otherwise it is rebuilt.

        Args:
          x: 5-D tensor whose last two dimensions are (t1, t2).

        Returns:
          Tensor with the same shape as x, keeping only the entries strictly
          above the shifted diagonal.
        """
        t1, t2 = x.size(3), x.size(4)
        # we add 1 to put the diagonal to 0 so we don't count the R_0 embedding twice
        diag = t2 - t1 + 1
        cached = self._triu
        if (
            cached is None
            or cached.size(3) < t1
            or cached.size(4) < t2
            or self._triu_diag != diag
            # a mask built for another dtype/device must not be reused: it
            # would silently promote the result dtype or fail with a device
            # mismatch
            or cached.dtype != x.dtype
            or cached.device != x.device
        ):
            # in these cases we need to recompute the upper triangular mask
            ones = torch.ones((t1, t2), dtype=x.dtype, device=x.device)
            self._triu = torch.triu(ones, diag)[None, None, None, :, :]
            self._triu_diag = diag
            triu = self._triu
        else:
            # same diagonal offset: the bottom-right corner of the cached mask
            # is the mask for the smaller matrix
            triu = cached[:, :, :, -t1:, -t2:]

        return x * triu

    def _left_shift(self, x, context, left_shift):
        """Applies left shifts to the rows of x
           to get scores with relative pos encodings R_{i-j}
           i-j >=0, causal attention

        E.g.
        [q0 R3, q0 R2, q0 R1, q0 R0;
         q1 R3, q1 R2, q1 R1, q1 R0;
         q2 R3, q2 R2, q2 R1, q2 R0]

        becomes:
        [q0 R1, q0 R0, 0    , 0    ;
         q1 R2, q1 R1, q1 R0, 0    ;
         q2 R3, q2 R2, q2 R1, q2 R0]

        Args:
          x: 4-D score tensor; dimension 2 is the query time axis, which is
             split into blocks of `context` rows.
          context: number of query rows per block.
          left_shift: if > 0, the first `left_shift` and the last
             `context - left_shift` query rows are dropped before blocking,
             which yields the blocks shifted by `left_shift` rows.

        Returns:
          5-D tensor (dim0, dim1, blocks, context, x.size(-1)) with the shift
          applied inside every block and masked by self._apply_tril.
        """
        if left_shift > 0:
            right_shift = context - left_shift
            x = x[:, :, left_shift:-right_shift]

        # split the query time axis into blocks of `context` rows
        x = x.view(x.size(0), x.size(1), -1, context, x.size(-1))
        # prepend one zero column and reinterpret the buffer with the last two
        # dims swapped, so that each row starts one element later
        x_pad = nn.functional.pad(x, (1, 0), mode="constant", value=0)
        x_pad = x_pad.view(*x.size()[:3], x.size(4) + 1, x.size(3))
        x = x_pad[:, :, :, 1:].view_as(x)
        return self._apply_tril(x)

    def _right_shift(self, x, context, left_shift):
        """Applies right shifts to the rows of x
           to get scores with relative pos encodings R_{i-j}
           i-j < 0, non-causal attention

        E.g.
        [q0 R_0, q0 R_{-1}, q0 R_{-2};
         q1 R_0, q1 R_{-1}, q1 R_{-2};
         q2 R_0, q2 R_{-1}, q2 R_{-2}]

        becomes:
        [ 0, q0 R_{-1}, q0 R_{-2};
          0, 0        , q1 R_{-1};
          0, 0        , 0        ]

        Args:
          x: 4-D score tensor; dimension 2 is the query time axis, which is
             split into blocks of `context` rows.
          context: number of query rows per block.
          left_shift: if > 0, the first `left_shift` and the last
             `context - left_shift` query rows are dropped before blocking,
             which yields the blocks shifted by `left_shift` rows.

        Returns:
          5-D tensor (dim0, dim1, blocks, context, x.size(-1)) with the shift
          applied inside every block and masked by self._apply_triu.
        """
        if left_shift > 0:
            right_shift = context - left_shift
            x = x[:, :, left_shift:-right_shift]

        # split the query time axis into blocks of `context` rows
        x = x.view(x.size(0), x.size(1), -1, context, x.size(-1))
        # append one zero column and reinterpret the buffer with the last two
        # dims swapped, so that each row starts one element earlier
        x_pad = nn.functional.pad(x, (0, 1), mode="constant", value=0)
        x_pad = x_pad.view(*x.size()[:3], x.size(4) + 1, x.size(3))
        x = x_pad[:, :, :, :-1].view_as(x)
        return self._apply_triu(x)

    def forward(self, query, key, value, pos_emb=None, mask=None):
        """Computes 'Local Scaled Dot Product Attention' with relative positional encodings.

        Args:
          query: query with size=(batch, time1, in_feats),
                 where time1 is the output time dimension
          key: key with size=(batch, time2, in_feats)
                 where time2 is the input time dimension
          value: value with size=(batch, time2, in_feats)
          pos_emb: positional embedding size=(batch, time2, in_feats) as R_{L-1}, ..., R_0.
                 Only the last context_k embeddings are used. It is required,
                 the None default only keeps the signature compatible with
                 the other attention layers.
          mask: optional mask with size=(batch, time1, time2),
                 to zero attention between some time steps
                 or size=(batch, time) to make time1=time2

        Returns:
          Attention weighted average of the value with size=(batch, time1, out_feats)

        Raises:
          ValueError: if pos_emb is None.
        """
        if pos_emb is None:
            raise ValueError(
                "pos_emb must be provided to compute relative positional attention"
            )

        batch_size = query.size(0)
        t1 = query.size(self.time_dim)
        t2 = key.size(self.time_dim)

        q0, k0, v0, context_q, context_k, num_blocks = self._compute_qkv0(
            query, key, value
        )
        # q0 size=(batch, time1, head*d_k)
        # k0 size=(batch, time2, head*d_k)
        # v0 size=(batch, time2, head*d_v)

        # u is flattened to (1, head*d_k) so that it broadcasts over batch and time
        q_plus_u0 = q0 + self.u.view(-1, q0.size(-1))  # (batch, time1, head*d_k)

        # compute A(a) + A(c) in Sec3.3, 2nd Eq. block diagonals
        # 1) compute block diagonal affinity matrix
        AC1 = self._compute_scores(
            q_plus_u0, k0, num_blocks, context_q, context_k, 0, 0
        )
        # (batch, head, blocks, context_q, context_k)

        # 2) compute shifted block diagonal matrix
        q_left_shift = context_q // 2
        k_left_shift = context_k // 2
        AC2 = self._compute_scores(
            q_plus_u0,
            k0,
            num_blocks - 1,
            context_q,
            context_k,
            q_left_shift,
            k_left_shift,
        )
        # (batch, head, blocks-1, context_q, context_k)

        # only the embeddings of the last context_k positions are needed
        pos_emb = pos_emb[:, -context_k:]  # (1, context_k, d_model)
        pos_batch_size = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(pos_batch_size, -1, self.num_heads, self.d_k)
        p = p.transpose(1, 2)  # (1, head, context_k, d_k)

        q = q0.view(
            batch_size, -1, self.num_heads, self.d_k
        )  # (batch, time1, head, d_k)
        q_plus_v = (q + self.v).transpose(1, 2)  # (batch, head, time1, d_k)

        # compute A(b) + A(d) in Sec3.3, 2nd Eq. for the causal part
        # This is the sum of Btilde and Dtilde in the Appendix of the paper
        BDtilde = torch.matmul(q_plus_v, p.transpose(-2, -1)) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, context_k)
        # apply left shift as indicated in the Appendix to get B+D
        # 1) block-diagonal part of BD: BD1
        BD1 = self._left_shift(
            BDtilde, context_q, 0
        )  # (batch, head, blocks, context_q, context_k)
        # 2) shifted block diagonal part of BD: BD2
        BD2 = self._left_shift(
            BDtilde, context_q, q_left_shift
        )  # (batch, head, blocks-1, context_q, context_k)

        if not self.causal_pos_enc:
            # compute A(b) + A(d) for the non-causal part,
            # this is not included in the paper because it doesn't allow to
            # attend to future positions
            # we assume that t2 >= t1, and therefore context_k >= context_q
            dt = context_k - context_q
            # flip returns a copy, so the in-place sign change below does not
            # modify the caller's pos_emb
            pos_emb_noncausal = pos_emb[:, dt:].flip(
                dims=(1,)
            )  # we flip to get R_0, ..., R_{L-1}
            pos_emb_noncausal[
                :, :, 0::2
            ] *= -1  # we multiply sin emb by -1 to get R_0, R_{-1}, ..., R_{-(L-1)}
            p = self.linear_pos(pos_emb_noncausal).view(
                pos_batch_size, -1, self.num_heads, self.d_k
            )
            p = p.transpose(1, 2)  # (batch, head, context_k-dt, d_k)
            BDtilde = torch.matmul(q_plus_v, p.transpose(-2, -1)) / math.sqrt(
                self.d_k
            )  # (batch, head, time1, context_k-dt)
            BD_noncausal1 = self._right_shift(
                BDtilde, context_q, 0
            )  # (batch, head, blocks, context_q, context_k-dt)
            BD_noncausal2 = self._right_shift(
                BDtilde, context_q, q_left_shift
            )  # (batch, head, blocks-1, context_q, context_k-dt)
            BD1[:, :, :, :, dt:] += BD_noncausal1
            BD2[:, :, :, :, dt:] += BD_noncausal2

        # add AC and BD for block-diag scores
        scores1 = AC1 + BD1  # (batch, head, blocks, context_q, context_k)
        scores2 = AC2 + BD2  # (batch, head, blocks-1, context_q, context_k)
        # NOTE(review): the return value is not used here; presumably the parent
        # implementation stores the attention weights on self - confirm in
        # LocalScaledDotProdAttV1
        self._compute_softmax(
            scores1, scores2, mask, q_left_shift, k_left_shift, t1, t2
        )
        return self._apply_attn(v0, t1)