Source code for hyperion.torch.layers.spec_augment

"""
 Copyright 2021 Johns Hopkins University  (Author: Jesus Villalba, Nanxin Chen)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
from jsonargparse import ArgumentParser, ActionParser

import torch
import torch.nn as nn
import torch.nn.functional as nnf

# Module-level counter. No active (non-commented) code in the visible module
# reads or updates it.
# NOTE(review): looks like a debugging leftover -- confirm nothing outside this
# module imports it before removing.
count = 0


class AxisMasker(nn.Module):
    """Applies a mask to the spectrogram along time or freq dimension.

    Implementation based on espnet.

    Attributes:
      min_width: minimum width of each mask.
      max_width: maximum width of each mask (inclusive).
      min_num_masks: minimum number of masks (0 is allowed).
      max_num_masks: maximum number of masks (inclusive).
      dim: axis where we apply the mask; must refer to one of the last
        two axes (time or freq), given as a negative or positive index.
      fill_value: masking value.
    """

    def __init__(
        self,
        min_width=0,
        max_width=30,
        min_num_masks=1,
        max_num_masks=2,
        dim=-1,
        fill_value=0,
    ):
        super().__init__()
        assert min_width >= 0
        assert max_width > 0
        assert min_num_masks >= 0
        assert max_num_masks > 0

        self.min_width = min_width
        self.max_width = max_width
        self.min_num_masks = min_num_masks
        self.max_num_masks = max_num_masks
        self.dim = dim
        self.fill_value = fill_value

    def __repr__(self):
        s = (
            "{}(min_width={}, max_width={}, "
            "min_num_masks={}, max_num_masks={}, "
            "dim={}, fill_value={})"
        ).format(
            self.__class__.__name__,
            self.min_width,
            self.max_width,
            self.min_num_masks,
            self.max_num_masks,
            self.dim,
            self.fill_value,
        )
        return s

    def forward(self, x):
        """Apply mask along time or freq dimension.

        The number of masks is drawn once per call and shared by the whole
        batch; widths and start positions are drawn per batch element.

        Args:
          x: spectrogram (batch, *, time, freq)

        Returns:
          Masked spectrogram (batch, *, time, freq). In eval mode, or when
          zero masks are drawn, the input tensor itself is returned.
        """
        if not self.training:
            return x

        # select how many masks
        num_masks = int(
            torch.randint(
                self.min_num_masks,
                self.max_num_masks + 1,
                size=(1,),
                device=x.device,
            ).item()
        )
        # Nothing to mask: avoids reducing over an empty tensor below.
        if num_masks == 0 or x.numel() == 0:
            return x

        in_shape = x.shape
        ndim = x.dim()
        # Use a negative axis index so that it stays valid after the
        # leading dimensions are flattened.
        dim = self.dim if self.dim < 0 else self.dim - ndim
        if ndim > 3:
            # reshape (not view) so non-contiguous inputs are accepted
            x = x.reshape(-1, x.shape[-2], x.shape[-1])

        batch_size = x.shape[0]
        masked_dim_length = x.shape[dim]

        # (batch, num_mask, 1)
        # Widths are clamped to the axis length so that the range of valid
        # start positions below is never empty.
        widths = (
            torch.randint(
                self.min_width,
                self.max_width + 1,
                size=(batch_size, num_masks),
                device=x.device,
            )
            .clamp(max=masked_dim_length)
            .unsqueeze(-1)
        )
        max_start_pos = masked_dim_length - int(widths.max().item()) + 1

        # (batch, num_mask, 1)
        start_pos = torch.randint(
            0, max_start_pos, size=(batch_size, num_masks), device=x.device
        ).unsqueeze(-1)

        # (1, 1, masked_dim_length)
        ref = torch.arange(masked_dim_length, device=x.device).view(1, 1, -1)
        # (batch, num_mask, mask_dim_length)
        mask = (start_pos <= ref) * (ref < (start_pos + widths))
        # (batch, mask_dim_length): union of all masks
        mask = mask.any(dim=1)

        if dim == -1:
            # freq masking: broadcast over time
            mask = mask.unsqueeze(-2)
        else:
            # time masking: broadcast over freq
            mask = mask.unsqueeze(-1)

        x = x.masked_fill(mask, self.fill_value)
        if ndim > 3:
            x = x.reshape(in_shape)

        return x
class SpecWarper(nn.Module):
    """Warps the spectrogram along time or freq dimension.

    Implementation based on espnet.

    Attributes:
      window: warp parameter; the warp center is moved by at most this
        many frames/bins.
      mode: interpolation mode passed to torch.nn.functional.interpolate.
      dim: axis to warp, -2 (time) or -1 (freq); the equivalent positive
        index is also accepted.
    """

    def __init__(self, window=80, mode="bicubic", dim=-2):
        super().__init__()
        self.window = window
        self.mode = mode
        self.dim = dim

    def __repr__(self):
        s = ("{}(window={}, mode={}, dim={})").format(
            self.__class__.__name__, self.window, self.mode, self.dim
        )
        return s

    def forward(self, x, lengths=None):
        """Warps x along time or freq dimension.

        A random center is chosen and the segments to its left and right
        are resampled so that the center moves by at most `window`
        positions while the total length is preserved.

        Args:
          x: spectrogram (batch, *, time, freq)
          lengths: length ratios of the utterances in the batch; only the
            minimum is used, and only for time warping.

        Returns:
          warped spectrogram (batch, *, time, freq). The input tensor itself
          is returned in eval mode or when the axis is too short to warp
          with the configured window.
        """
        if not self.training:
            return x

        in_shape = x.shape
        ndim = x.dim()
        # Normalise the axis to a negative index.
        dim = self.dim - ndim if self.dim >= 0 else self.dim
        warp_freq = dim == -1

        # Interpolation below works on (batch, C, time, freq).
        if ndim == 3:
            y = x.unsqueeze(1)
        elif ndim > 4:
            y = x.flatten(1, -3)
        else:
            y = x

        # for warping in freq dimension
        if warp_freq:
            y = y.transpose(-1, -2)

        # to make it batcheable we are going to warp
        # the first n frames where n is the length of the
        # shortest utterance
        # the end of the utterance will not be warped
        full_length = y.shape[-2]
        if warp_freq or lengths is None:
            warp_length = full_length
        else:
            warp_length = min(int(full_length * torch.min(lengths)), full_length)

        # Not enough room to place a center with `window` positions on each
        # side (or nothing to warp): leave the input untouched.
        if self.window <= 0 or warp_length - self.window <= self.window:
            return x

        center = int(
            torch.randint(self.window, warp_length - self.window, (1,)).item()
        )
        warped = (
            int(
                torch.randint(
                    center - self.window, center + self.window, (1,)
                ).item()
            )
            + 1
        )

        # align_corners is only accepted by the interpolating modes.
        if self.mode in ("linear", "bilinear", "bicubic", "trilinear"):
            align_corners = False
        else:
            align_corners = None

        # (batch, C, warped, freq)
        left = nnf.interpolate(
            y[:, :, :center],
            (warped, y.shape[3]),
            mode=self.mode,
            align_corners=align_corners,
        )
        # (batch, C, warp_length - warped, freq)
        right = nnf.interpolate(
            y[:, :, center:warp_length],
            (warp_length - warped, y.shape[3]),
            mode=self.mode,
            align_corners=align_corners,
        )

        if warp_length != full_length:
            right_nowarp = y[:, :, warp_length:]
            y = torch.cat([left, right, right_nowarp], dim=-2)
        else:
            y = torch.cat([left, right], dim=-2)

        if warp_freq:
            y = y.transpose(-1, -2)

        # reshape (not view): y may be non-contiguous after the transpose
        return y.reshape(in_shape)
class SpecAugment(nn.Module):
    """Implementation of SpecAugment.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data Augmentation Method for
        Automatic Speech Recognition"

    Attributes:
      time_warp_prob: probability of applying time warping.
      time_warp_window: time warp window parameter.
      time_warp_mode: interpolation mode used for time warping.
      time_mask_prob: probability of applying time masking.
      time_mask_min_width: min. width of the time masks.
      time_mask_max_width: max. width of the time masks.
      time_mask_min_num_masks: min. number of time masks.
      time_mask_max_num_masks: max. number of time masks.
      freq_mask_prob: probability of applying freq. masking.
      freq_mask_min_width: min. width of the freq. masks.
      freq_mask_max_width: max. width of the freq. masks.
      freq_mask_min_num_masks: min. number of freq. masks.
      freq_mask_max_num_masks: max. number of freq. masks.
      fill_value: value written into the masked bins.
    """

    def __init__(
        self,
        time_warp_prob=0,
        time_warp_window=5,
        time_warp_mode="bicubic",
        time_mask_prob=0,
        time_mask_min_width=0,
        time_mask_max_width=100,
        time_mask_min_num_masks=1,
        time_mask_max_num_masks=2,
        freq_mask_prob=0,
        freq_mask_min_width=0,
        freq_mask_max_width=20,
        freq_mask_min_num_masks=1,
        freq_mask_max_num_masks=2,
        fill_value=0,
    ):
        super().__init__()

        self.time_warp_prob = time_warp_prob
        self.time_warp_window = time_warp_window
        self.time_warp_mode = time_warp_mode
        self.time_mask_prob = time_mask_prob
        self.time_mask_min_width = time_mask_min_width
        self.time_mask_max_width = time_mask_max_width
        self.time_mask_min_num_masks = time_mask_min_num_masks
        self.time_mask_max_num_masks = time_mask_max_num_masks
        self.freq_mask_prob = freq_mask_prob
        self.freq_mask_min_width = freq_mask_min_width
        self.freq_mask_max_width = freq_mask_max_width
        self.freq_mask_min_num_masks = freq_mask_min_num_masks
        self.freq_mask_max_num_masks = freq_mask_max_num_masks
        self.fill_value = fill_value

        # Sub-modules are only built for augmentations with non-zero prob.
        self.time_masker = None
        self.freq_masker = None
        self.time_warper = None

        if self.time_mask_prob > 0:
            self.time_masker = AxisMasker(
                min_width=time_mask_min_width,
                max_width=time_mask_max_width,
                min_num_masks=time_mask_min_num_masks,
                max_num_masks=time_mask_max_num_masks,
                dim=-2,
                fill_value=fill_value,
            )

        if self.freq_mask_prob > 0:
            self.freq_masker = AxisMasker(
                min_width=freq_mask_min_width,
                max_width=freq_mask_max_width,
                min_num_masks=freq_mask_min_num_masks,
                max_num_masks=freq_mask_max_num_masks,
                dim=-1,
                fill_value=fill_value,
            )

        if self.time_warp_prob > 0:
            self.time_warper = SpecWarper(
                window=time_warp_window, mode=time_warp_mode, dim=-2
            )

    def __repr__(self):
        s = (
            "{}(time_warper(p={})={}, time_masker(p={})={}, freq_masker(p={})={})"
        ).format(
            self.__class__.__name__,
            self.time_warp_prob,
            self.time_warper,
            self.time_mask_prob,
            self.time_masker,
            self.freq_mask_prob,
            self.freq_masker,
        )
        return s

    def forward(self, x, lengths=None):
        """Applies time warping, time masking and freq. masking at random.

        Each augmentation is applied independently, with its own
        probability, in the order warp -> time mask -> freq. mask.

        Args:
          x: spectrogram, passed to the warper/maskers.
          lengths: forwarded to the time warper only.

        Returns:
          Augmented spectrogram; the input tensor itself in eval mode.
        """
        if not self.training:
            return x

        # One uniform draw per augmentation. A probability of 0 can never
        # exceed a draw from [0, 1), so the sub-modules that were not
        # created in __init__ are never called.
        r = torch.rand((3,), device=x.device)
        if self.time_warp_prob > r[0]:
            x = self.time_warper(x, lengths)

        if self.time_mask_prob > r[1]:
            x = self.time_masker(x)

        if self.freq_mask_prob > r[2]:
            x = self.freq_masker(x)

        return x

    @staticmethod
    def filter_args(**kwargs):
        """Filters SpecAugment args from arguments dictionary.

        Args:
          kwargs: Arguments dictionary.

        Returns:
          Dictionary with SpecAugment options.
        """
        valid_args = (
            "time_warp_prob",
            "time_warp_window",
            "time_warp_mode",
            "time_mask_prob",
            "time_mask_max_width",
            "time_mask_min_width",
            "time_mask_max_num_masks",
            "time_mask_min_num_masks",
            "freq_mask_prob",
            "freq_mask_max_width",
            "freq_mask_min_width",
            "freq_mask_max_num_masks",
            "freq_mask_min_num_masks",
            "fill_value",
        )

        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Adds SpecAugment options to parser.

        Args:
          parser: Arguments parser
          prefix: Options prefix. When given, the options are added to a
            fresh parser that is attached to `parser` under `--<prefix>`.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        parser.add_argument(
            "--time-warp-prob",
            type=float,
            default=0.0,
            help="prob. for applying warping",
        )
        # NOTE(review): default (80) differs from the constructor default
        # for time_warp_window (5) -- confirm which one is intended.
        parser.add_argument(
            "--time-warp-window", type=int, default=80, help="time warp window param."
        )
        # NOTE(review): "linear" and "trilinear" expect 3D / 5D inputs in
        # torch.nn.functional.interpolate -- confirm they are usable here.
        parser.add_argument(
            "--time-warp-mode",
            default="bicubic",
            choices=["bilinear", "linear", "nearest", "bicubic", "trilinear"],
            help="interpolation mode for time warping",
        )

        parser.add_argument(
            "--time-mask-prob",
            type=float,
            default=0.0,
            help="prob. for applying time masking",
        )
        parser.add_argument(
            "--time-mask-min-width",
            type=int,
            default=0,
            help="min. width for time mask",
        )
        parser.add_argument(
            "--time-mask-max-width",
            type=int,
            default=100,
            help="max. width for time mask",
        )
        parser.add_argument(
            "--time-mask-min-num-masks",
            type=int,
            default=1,
            help="min. number of time mask",
        )
        parser.add_argument(
            "--time-mask-max-num-masks",
            type=int,
            default=2,
            help="max. number of time mask",
        )

        parser.add_argument(
            "--freq-mask-prob",
            type=float,
            default=0.0,
            help="prob. for applying freq. masking",
        )
        parser.add_argument(
            "--freq-mask-min-width",
            type=int,
            default=0,
            help="min. width for freq mask",
        )
        # NOTE(review): default (100) differs from the constructor default
        # for freq_mask_max_width (20) -- confirm which one is intended.
        parser.add_argument(
            "--freq-mask-max-width",
            type=int,
            default=100,
            help="max. width for freq mask",
        )
        parser.add_argument(
            "--freq-mask-min-num-masks",
            type=int,
            default=1,
            help="min. number of freq mask",
        )
        parser.add_argument(
            "--freq-mask-max-num-masks",
            type=int,
            default=2,
            help="max. number of freq mask",
        )

        parser.add_argument(
            "--fill-value",
            type=float,
            default=0.0,
            help="filling value for the masked spec. bins",
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))