Source code for hyperion.torch.layers.global_pool

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as nnf

SQRT_EPS = 1e-5


[docs]def _conv1(in_channels, out_channels, bias=False): """point-wise convolution""" return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)
class _GlobalPool1d(nn.Module): def __init__(self, dim=-1, keepdim=False): super().__init__() self.dim = dim self.keepdim = keepdim self.size_multiplier = 1 def _standarize_weights(self, weights, ndims): if weights.dim() == ndims: return weights assert weights.dim() == 2 shape = ndims * [1] shape[0] = weights.shape[0] shape[self.dim] = weights.shape[1] return weights.view(tuple(shape)) def get_config(self): config = {"dim": self.dim, "keepdim": self.keepdim} return config def forward_slidwin(self, x, win_length, win_shift): raise NotImplementedError() def _slidwin_pad(self, x, win_length, win_shift, snip_edges): if snip_edges: num_frames = int( math.floor((x.size(-1) - win_length + win_shift) / win_shift) ) return nnf.pad(x, (1, 0), mode="constant"), num_frames assert ( win_length >= win_shift ), "if win_length < win_shift snip-edges should be false" num_frames = int(round(x.size(-1) / win_shift)) len_x = (num_frames - 1) * win_shift + win_length dlen_x = round(len_x - x.size(-1)) pad_left = int(math.floor((win_length - win_shift) / 2)) pad_right = int(dlen_x - pad_left) return nnf.pad(x, (pad_left + 1, pad_right), mode="reflect"), num_frames
[docs]class GlobalAvgPool1d(_GlobalPool1d): """Global average pooling in 1d Attributes: dim: pooling dimension keepdim: it True keeps the same number of dimensions after pooling """
[docs] def __init__(self, dim=-1, keepdim=False): super().__init__(dim, keepdim)
[docs] def forward(self, x, weights=None): if weights is None: y = torch.mean(x, dim=self.dim, keepdim=self.keepdim) return y weights = self._standarize_weights(weights, x.dim()) xbar = torch.mean(weights * x, dim=self.dim, keepdim=self.keepdim) wbar = torch.mean(weights, dim=self.dim, keepdim=self.keepdim) return xbar / wbar
[docs] def forward_slidwin(self, x, win_length, win_shift, snip_edges=False): if isinstance(win_shift, int) and isinstance(win_length, int): return self._forward_slidwin_int( x, win_length, win_shift, snip_edges=snip_edges ) # the window length and/or shift are floats return self._forward_slidwin_float( x, win_length, win_shift, snip_edges=snip_edges )
def _pre_slidwin(self, x, win_length, win_shift, snip_edges): if self.dim != -1: x = x.transpose(self.dim, -1) x, num_frames = self._slidwin_pad(x, win_length, win_shift, snip_edges) out_shape = *x.shape[:-1], num_frames c_x = torch.cumsum(x, dim=-1).view(-1, x.shape[-1]) return c_x, out_shape def _post_slidwin(self, m_x, x_shape): m_x = m_x.view(x_shape) if self.dim != -1: m_x = m_x.transpose(self.dim, -1).contiguous() return m_x def _forward_slidwin_int(self, x, win_length, win_shift, snip_edges): c_x, out_shape = self._pre_slidwin(x, win_length, win_shift, snip_edges) m_x = (c_x[:, win_shift:] - c_x[:, :-win_shift]) / win_length m_x = self._post_slid_win(m_x, out_shape) return m_x def _forward_slidwin_float(self, x, win_length, win_shift, snip_edges): c_x, out_shape = self._pre_slidwin(x, win_length, win_shift, snip_edges) num_frames = out_shape[-1] m_x = torch.zeros( (c_x.shape[0], num_frames), dtype=c_x.dtype, device=c_x.device ) k = 0 for i in range(num_frames): k1 = int(round(k)) k2 = int(round(k + win_length)) m_x[:, i] = (c_x[:, k2] - c_x[:, k1]) / (k2 - k1) k += win_shift m_x = self._post_slid_win(m_x, out_shape) return m_x
[docs]class GlobalMeanStdPool1d(_GlobalPool1d): """Global mean + standard deviation pooling in 1d Attributes: dim: pooling dimension keepdim: it True keeps the same number of dimensions after pooling """
[docs] def __init__(self, dim=-1, keepdim=False): super().__init__(dim, keepdim) self.size_multiplier = 2
[docs] def forward(self, x, weights=None): if weights is None: mu = torch.mean(x, dim=self.dim, keepdim=True) delta = x - mu mu.squeeze_(dim=self.dim) # this can produce slightly negative variance when relu6 saturates in all time steps # add 1e-5 for stability s = torch.sqrt( torch.mean(delta ** 2, dim=self.dim, keepdim=False).clamp(min=SQRT_EPS) ) mus = torch.cat((mu, s), dim=1) if self.keepdim: mus.unsqueeze_(dim=self.dim) return mus weights = self._standarize_weights(weights, x.dim()) xbar = torch.mean(weights * x, dim=self.dim, keepdim=True) wbar = torch.mean(weights, dim=self.dim, keepdim=True) mu = xbar / wbar delta = x - mu var = torch.mean(weights * delta ** 2, dim=self.dim, keepdim=True) / wbar s = torch.sqrt(var.clamp(min=SQRT_EPS)) mu = mu.squeeze(self.dim) s = s.squeeze(self.dim) mus = torch.cat((mu, s), dim=1) if self.keepdim: mus.unsqueeze_(dim=self.dim) return mus
[docs] def forward_slidwin(self, x, win_length, win_shift, snip_edges=False): if isinstance(win_shift, int) and isinstance(win_length, int): return self._forward_slidwin_int(x, win_length, win_shift, snip_edges) # the window length and/or shift are floats return self._forward_slidwin_float(x, win_length, win_shift, snip_edges)
def _pre_slidwin(self, x, win_length, win_shift, snip_edges): if self.dim != -1: x = x.transpose(self.dim, -1) x, num_frames = self._slidwin_pad(x, win_length, win_shift, snip_edges) out_shape = *x.shape[:-1], num_frames return x, out_shape def _post_slidwin(self, m_x, s_x, out_shape): m_x = m_x.view(out_shape) s_x = s_x.view(out_shape) mus = torch.cat((m_x, s_x), dim=1) if self.dim != -1: mus = mus.transpose(self.dim, -1).contiguous() return mus def _forward_slidwin_int(self, x, win_length, win_shift, snip_edges): x, out_shape = self._pre_slidwin(x, win_length, win_shift, snip_edges) c_x = torch.cumsum(x, dim=-1).view(-1, x.shape[-1]) m_x = (c_x[:, win_shift:] - c_x[:, :-win_shift]) / win_length c_x = torch.cumsum(x ** 2, dim=-1).view(-1, x.shape[-1]) m_x2 = (c_x[:, win_shift:] - c_x[:, :-win_shift]) / win_length s_x = torch.sqrt(m_x2 - m_x ** 2).clamp(min=SQRT_EPS) mus = self._post_slidwin(m_x, s_x, out_shape) return mus def _forward_slidwin_float(self, x, win_length, win_shift, snip_edges): x, out_shape = self._pre_slidwin(x, win_length, win_shift, snip_edges) num_frames = out_shape[-1] c_x = torch.cumsum(x, dim=-1).view(-1, x.shape[-1]) c_x2 = torch.cumsum(x ** 2, dim=-1).view(-1, x.shape[-1]) # xx = x.view(-1, x.shape[-1]) # print(xx.shape[1]) # print(torch.max(torch.sum(xx==0, dim=1))) m_x = torch.zeros( (c_x.shape[0], num_frames), dtype=c_x.dtype, device=c_x.device ) m_x2 = torch.zeros_like(m_x) k = 0 # max_delta = 0 # max_delta2 = 0 for i in range(num_frames): k1 = int(round(k)) k2 = int(round(k + win_length)) m_x[:, i] = (c_x[:, k2] - c_x[:, k1]) / (k2 - k1) m_x2[:, i] = (c_x2[:, k2] - c_x2[:, k1]) / (k2 - k1) # for j in range(m_x.shape[0]): # m_x_2 = torch.mean(xx[j,k1+1:k2+1]) # m_x2_2 = torch.mean(xx[j,k1+1:k2+1]**2) # delta = torch.abs(m_x_2 - m_x[j,i]).item() # delta2 = torch.abs(m_x2_2 - m_x2[j,i]).item() # if (delta > max_delta or delta2 > max_delta2) and (delta>1e-3 or delta2>1e-3): # max_delta = delta # max_delta2 = delta2 # print('mx', delta, m_x[j,i], m_x_2) # print('mx2', delta2, m_x2[j,i], m_x2_2) # import sys # sys.stdout.flush() # # if m_x[j,i]**2 > m_x2[j,i]: # # print('nan') # # print('mx', m_x[j,i], m_x_2) # # print('mx2', m_x2[j,i], m_x2_2) # # print(c_x[j,k2]) # # print(c_x[j,k1]) # # print(c_x2[j,k2]) # # print(c_x2[j,k1]) # # print(xx[j,k1+1:k2+1]) # # raise Exception() k += win_shift var_x = (m_x2 - m_x ** 2).clamp(min=SQRT_EPS) s_x = torch.sqrt(var_x) # idx = torch.isnan(s_x) #.any(dim=1) # if torch.sum(idx) > 0: # print('sx-nan', s_x[idx]) # print('mx-nan', m_x[idx]) # print('mx2-nan', m_x2[idx]) # print('var-nan', m_x2[idx]-m_x[idx]**2) # #print('cx2-nan', c_x2[idx]) # raise Exception() mus = self._post_slidwin(m_x, s_x, out_shape) return mus
# def _forward_slidwin_int(self, x, win_length, win_shift): # num_frames = int((x.shape[self.dim] - win_length + 2*window_shift -1)/win_shift) # pad_right = win_shift * (num_frames - 1) + win_length # if self.dim != -1: # # put pool dim at the end to do the padding # x = x.transpose(self.dim, -1) # xx = nnf.pad(x, (1, pad_right), mode='reflect') # c_x = torch.cumsum(xx, dim=self.dim).transpose(0, -1) # m_x = (c_x[win_shift:] - c_x[:-win_shift]).transpose(0, self.dim)/win_length # c_x = torch.cumsum(xx**2, dim=-1).transpose(0, -1) # m_x2 = (c_x[win_shift:] - c_x[:-win_shift]).transpose(0, self.dim)/win_length # s_x = torch.sqrt(m_x2 - m_x**2).clamp(min=1e-5) # if self.dim == -1: # return torch.cat((m_x, s_x), dim=-2) # return torch.cat((m_x, s_x), dim=-1) # def _forward_slidwin_float(self, x, win_shift, win_length): # num_frames = int((x.shape[self.dim] - win_length + 2*window_shift -1)/win_shift) # pad_right = win_shift * (num_frames - 1) + win_length # if self.dim != -1: # x = x.transpose(self.dim, -1) # xx = nnf.pad(x, (1, pad_right), mode='reflect') # c_x = torch.cumsum(xx, dim=-1).transpose(0, -1) # c_x2 = torch.cumsum(xx**2, dim=-1).transpose(0, -1) # m_x = [] # m_x2 = [] # k = 0 # for i in range(num_frames): # k1 = int(math.round(k)) # k2 = int(math.round(k+win_length)) # w = (k2-k1) # m_x.append((c_x[k2]-c_x[k1])/w) # m_x2.append((c_x2[k2]-c_x2[k1])/w) # k += win_shift # m_x = m_x.cat(tuple(y), dim=0).transpose(0, self.dim).contiguous() # m_x2 = m_x2.cat(tuple(y), dim=0).transpose(0, self.dim).contiguous() # s_x = torch.sqrt(m_x2 - m_x**2).clamp(min=1e-5) # if self.dim == -1: # return torch.cat((m_x, s_x), dim=-2) # return torch.cat((m_x, s_x), dim=-1)
[docs]class GlobalMeanLogVarPool1d(_GlobalPool1d): """Global mean + log-variance pooling in 1d Attributes: dim: pooling dimension keepdim: it True keeps the same number of dimensions after pooling """
[docs] def __init__(self, dim=-1, keepdim=False): super().__init__(dim, keepdim) self.size_multiplier = 2
[docs] def forward(self, x, weights=None): if weights is None: mu = torch.mean(x, dim=self.dim, keepdim=self.keepdim) x2bar = torch.mean(x ** 2, dim=self.dim, keepdim=self.keepdim) logvar = torch.log(x2bar - mu * mu + 1e-5) # for stability in case var=0 return torch.cat((mu, logvar), dim=-1) weights = self._standarize_weights(weights, x.dim()) xbar = torch.mean(weights * x, dim=self.dim, keepdim=self.keepdim) wbar = torch.mean(weights, dim=self.dim, keepdim=self.keepdim) mu = xbar / wbar x2bar = torch.mean(weights * x ** 2, dim=self.dim, keepdim=self.keepdim) / wbar var = (x2bar - mu * mu).clamp(min=1e-5) logvar = torch.log(var) return torch.cat((mu, logvar), dim=-1)
[docs]class LDEPool1d(_GlobalPool1d): """Learnable dictionary encoder pooling in 1d Attributes: in_feats: input feature dimension num_comp: number of cluster components dist_pow: power for distance metric use_bias: use bias parameter when computing posterior responsibility dim: pooling dimension keepdim: it True keeps the same number of dimensions after pooling """
[docs] def __init__( self, in_feats, num_comp=64, dist_pow=2, use_bias=False, dim=-1, keepdim=False ): super().__init__(dim, keepdim) self.mu = nn.Parameter(torch.randn((num_comp, in_feats))) self.prec = nn.Parameter(torch.ones((num_comp,))) self.use_bias = use_bias if use_bias: self.bias = nn.Parameter(torch.zeros((num_comp,))) else: self.bias = 0 self.dist_pow = dist_pow if dist_pow == 1: self.dist_f = lambda x: torch.norm(x, p=2, dim=-1) else: self.dist_f = lambda x: torch.sum(x ** 2, dim=-1) self.size_multiplier = num_comp
@property def num_comp(self): return self.mu.shape[0] @property def in_feats(self): return self.mu.shape[1] def __repr__(self): return self.__str__() def __str__(self): s = "{}(in_feats={}, num_comp={}, dist_pow={}, use_bias={}, dim={}, keepdim={})".format( self.__class__.__name__, self.mu.shape[1], self.mu.shape[0], self.dist_pow, self.use_bias, self.dim, self.keepdim, ) return s
[docs] def forward(self, x, weights=None): if self.dim != 1 or self.dim != -2: x = x.transpose(1, self.dim) x = torch.unsqueeze(x, dim=2) delta = x - self.mu dist = self.dist_f(delta) llk = -self.prec ** 2 * dist + self.bias r = nnf.softmax(llk, dim=-1) if weights is not None: r *= weights r = torch.unsqueeze(r, dim=-1) N = torch.sum(r, dim=1) + 1e-9 F = torch.sum(r * delta, dim=1) pool = F / N pool = pool.contiguous().view(-1, self.num_comp * self.in_feats) if self.keepdim: if self.dim == 1 or self.dim == -2: pool.unsqueeze_(1) else: pool.unsqueeze_(-1) return pool
[docs] def get_config(self): config = { "in_feats": self.in_feats, "num_comp": self.num_comp, "dist_pow": self.dist_pow, "use_bias": self.use_bias, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items()))
[docs]class ScaledDotProdAttV1Pool1d(_GlobalPool1d):
[docs] def __init__( self, in_feats, num_heads, d_k, d_v, bin_attn=False, dim=-1, keepdim=False ): super().__init__(dim, keepdim) self.d_v = d_v self.d_k = d_k self.num_heads = num_heads self.bin_attn = bin_attn self.q = nn.Parameter(torch.Tensor(1, num_heads, 1, d_k)) nn.init.orthogonal_(self.q) if self.bin_attn: self.bias = nn.Parameter(torch.zeros((1, num_heads, 1, 1))) self.linear_k = nn.Linear(in_feats, num_heads * d_k) self.linear_v = nn.Linear(in_feats, num_heads * d_v) self.attn = None self.size_multiplier = num_heads * d_v / in_feats
@property def in_feats(self): return self.linear_v.in_features def __repr__(self): return self.__str__() def __str__(self): s = "{}(in_feats={}, num_heads={}, d_k={}, d_v={}, bin_attn={}, dim={}, keepdim={})".format( self.__class__.__name__, self.in_feats, self.num_heads, self.d_k, self.d_v, self.bin_attn, self.dim, self.keepdim, ) return s
[docs] def forward(self, x, weights=None): batch_size = x.size(0) if self.dim != 1: x = x.transpose(1, self.dim) k = self.linear_k(x).view(batch_size, -1, self.num_heads, self.d_k) v = self.linear_v(x).view(batch_size, -1, self.num_heads, self.d_v) k = k.transpose(1, 2) # (batch, head, time, d_k) v = v.transpose(1, 2) # (batch, head, time, d_v) scores = torch.matmul(self.q, k.transpose(-2, -1)) / math.sqrt( self.d_k ) # (batch, head, 1, time) if self.bin_attn: scores = torch.sigmoid(scores + self.bias) # scores = scores.squeeze(dim=-1) # (batch, head, time) if weights is not None: mask = weights.view(batch_size, 1, 1, -1).eq(0) # (batch, 1, 1,time) if self.bin_attn: scores = scores.masked_fill(mask, 0.0) self.attn = scores / (torch.sum(scores, dim=-1, keepdim=True) + 1e-9) else: min_value = -1e200 scores = scores.masked_fill(mask, min_value) self.attn = torch.softmax(scores, dim=-1).masked_fill( mask, 0.0 ) # (batch, head, 1, time) else: if self.bin_attn: self.attn = scores / (torch.sum(scores, dim=-1, keepdim=True) + 1e-9) else: self.attn = torch.softmax(scores, dim=-1) # (batch, head, 1, time) x = torch.matmul(self.attn, v) # (batch, head, 1, d_v) if self.keepdim: x = x.view(batch_size, 1, self.num_heads * self.d_v) # (batch, 1, d_model) else: x = x.view(batch_size, self.num_heads * self.d_v) # (batch, d_model) return x
[docs] def get_config(self): config = { "in_feats": self.in_feats, "num_heads": self.num_heads, "d_k": self.d_k, "d_v": self.d_v, "bin_attn": self.bin_attn, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items()))
[docs]class GlobalChWiseAttMeanStdPool1d(_GlobalPool1d): """Attentive mean + stddev pooling for each channel"""
[docs] def __init__( self, in_feats, inner_feats=128, bin_attn=False, use_global_context=True, norm_layer=None, dim=-1, keepdim=False, ): super().__init__(dim, keepdim) self.size_multiplier = 2 self.in_feats = in_feats self.inner_feats = inner_feats self.bin_attn = bin_attn self.use_global_context = use_global_context self.conv1 = _conv1(in_feats, inner_feats) if use_global_context: self.lin_global = nn.Linear(2 * in_feats, inner_feats, bias=False) # torch.autograd.set_detect_anomaly(True) if norm_layer is None: norm_layer = nn.BatchNorm1d self.norm_layer = norm_layer(inner_feats) self.activation = nn.Tanh() self.conv2 = _conv1(inner_feats, in_feats, bias=True) self.stats_pool = GlobalMeanStdPool1d(dim=dim) if self.bin_attn: self.bias = nn.Parameter(torch.ones((1, in_feats, 1)))
def __repr__(self): return self.__str__() def __str__(self): s = "{}(in_feats={}, inner_feats={}, use_global_context={}, bin_attn={}, dim={}, keepdim={})".format( self.__class__.__name__, self.in_feats, self.inner_feats, self.use_global_context, self.bin_attn, self.dim, self.keepdim, ) return s
[docs] def forward(self, x, weights=None): x_inner = self.conv1(x) # logging.info('x_inner1={} {}'.format(torch.sum(torch.isnan(x_inner)), torch.sum(torch.isinf(x_inner)))) if self.use_global_context: global_mus = self.stats_pool(x) x_inner = x_inner + self.lin_global(global_mus).unsqueeze(-1) # logging.info('x_inner2={} {}'.format(torch.sum(torch.isnan(x_inner)), torch.sum(torch.isinf(x_inner)))) attn = self.conv2(self.activation(self.norm_layer(x_inner))) if self.bin_attn: # attn = torch.sigmoid(attn+self.bias) attn = torch.sigmoid(attn) else: attn = nnf.softmax(attn, dim=-1) mus = self.stats_pool(x, weights=attn) # logging.info('mus={} {}'.format(torch.sum(torch.isnan(mus)), torch.sum(torch.isinf(mus)))) return mus
[docs] def get_config(self): config = { "in_feats": self.in_feats, "inner_feats": self.inner_feats, "use_global_context": self.use_global_context, "bin_attn": self.bin_attn, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items()))