Source code for hyperion.torch.layers.audio_feats

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
#

import math
import logging

from ...utils.misc import str2bool

import torch
import torch.nn as nn
import torch.cuda.amp as amp

try:
    from torch.fft import rfft as torch_rfft

    _rfft = lambda x: torch_rfft(x, dim=-1)
    _pow_spectrogram = lambda x: x.abs() ** 2
    _spectrogram = lambda x: x.abs()
except:
    _rfft = lambda x: torch.rfft(x, 1, normalized=False, onesided=True)
    _pow_spectrogram = lambda x: x.pow(2).sum(-1)
    _spectrogram = lambda x: x.pow(2).sum(-1).sqrt()

from ...feats.filter_banks import FilterBankFactory as FBF

# window types
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]

# def _amp_safe_matmul(a, b):
#     if _use_amp():
#         mx = torch.max(a, dim=-1, keepdim=True)[0]
#         return mx*torch.matmul(a/mx, b)

#     return torch.matmul(a, b)


[docs]def _get_feature_window_function(window_type, window_size, blackman_coeff=0.42): r"""Returns a window function with the given type and size""" if window_type == HANNING: return torch.hann_window(window_size, periodic=True) elif window_type == HAMMING: return torch.hamming_window(window_size, periodic=True, alpha=0.54, beta=0.46) elif window_type == POVEY: # return torch.hann_window(window_size, periodic=True).pow(0.85) a = 2 * math.pi / window_size window_function = torch.arange(window_size, dtype=torch.get_default_dtype()) return (0.5 - 0.5 * torch.cos(a * window_function)).pow(0.85) elif window_type == RECTANGULAR: return torch.ones(window_size, dtype=torch.get_default_dtype()) elif window_type == BLACKMAN: a = 2 * math.pi / window_size window_function = torch.arange(window_size, dtype=torch.get_default_dtype()) return ( blackman_coeff - 0.5 * torch.cos(a * window_function) + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function) ) else: raise Exception("Invalid window type " + window_type)
[docs]def _get_strided_batch(waveform, window_length, window_shift, snip_edges, center=False): r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) representing how the window is shifted along the waveform. Each row is a frame. Args: waveform (torch.Tensor): Tensor of size ``num_samples`` window_size (int): Frame length window_shift (int): Frame shift snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame_length. If False, the number of frames depends only on the frame_shift, and we reflect the data at the ends. center (bool): If true, if puts the center of the frame at t*window_shift, starting at t=0, If overwrides snip_edges and set it to False Returns: torch.Tensor: 3D tensor of size (m, ``window_size``) where each row is a frame """ assert waveform.dim() == 2 batch_size = waveform.size(0) num_samples = waveform.size(-1) if center: snip_edges = False if snip_edges: if num_samples < window_length: return torch.empty((0, 0, 0)) else: num_frames = 1 + (num_samples - window_length) // window_shift else: if center: npad_left = int(window_length // 2) npad_right = npad_left npad = 2 * npad_left num_frames = 1 + (num_samples + npad - window_length) // window_shift else: num_frames = (num_samples + (window_shift // 2)) // window_shift new_num_samples = (num_frames - 1) * window_shift + window_length npad = new_num_samples - num_samples npad_left = int((window_length - window_shift) // 2) npad_right = npad - npad_left # waveform = nn.functional.pad(waveform, (npad_left, npad_right), mode='reflect') pad_left = torch.flip(waveform[:, 1 : npad_left + 1], (1,)) pad_right = torch.flip(waveform[:, -npad_right - 1 : -1], (1,)) waveform = torch.cat((pad_left, waveform, pad_right), dim=1) strides = ( waveform.stride(0), window_shift * waveform.stride(1), waveform.stride(1), ) sizes = (batch_size, num_frames, window_length) return waveform.as_strided(sizes, strides)
[docs]def _get_log_energy(x, energy_floor): r"""Returns the log energy of size (m) for a strided_input (m,*)""" log_energy = (x.pow(2).sum(-1) + 1e-15).log() # size (m) if energy_floor > 0.0: log_energy = torch.max( log_energy, torch.tensor(math.log(energy_floor), dtype=torch.get_default_dtype()), ) return log_energy
[docs]class Wav2Win(nn.Module):
[docs] def __init__( self, fs=16000, frame_length=25, frame_shift=10, pad_length=None, remove_dc_offset=True, preemph_coeff=0.97, window_type="povey", dither=1, snip_edges=True, center=False, energy_floor=0, raw_energy=True, return_log_energy=False, ): super().__init__() self.fs = fs self.frame_length = frame_length self.frame_shift = frame_shift self.remove_dc_offset = remove_dc_offset self.preemph_coeff = preemph_coeff self.window_type = window_type self.dither = dither self.snip_edges = snip_edges self.center = center self.energy_floor = energy_floor self.raw_energy = raw_energy self.return_log_energy = return_log_energy N = int(math.floor(frame_length * fs / 1000)) self._length = N self._shift = int(math.floor(frame_shift * fs / 1000)) self._window = nn.Parameter( _get_feature_window_function(window_type, N), requires_grad=False ) self.pad_length = N if pad_length is None else pad_length assert self.pad_length >= N
def __repr__(self): return self.__str__() def __str__(self): s = ( "{}(fs={}, frame_length={}, frame_shift={}, pad_length={}, " "remove_dc_offset={}, preemph_coeff={}, window_type={} " "dither={}, snip_edges={}, center={}, energy_floor={}, raw_energy={}, return_log_energy={})" ).format( self.__class__.__name__, self.fs, self.frame_length, self.frame_shift, self.pad_length, self.remove_dc_offset, self.preemph_coeff, self.window_type, self.dither, self.snip_edges, self.center, self.energy_floor, self.raw_energy, self.return_log_energy, ) return s
[docs] def forward(self, x): # Add dither if self.dither != 0.0: n = torch.randn(x.shape, device=x.device) x = x + self.dither * n # remove offset if self.remove_dc_offset: mu = torch.mean(x, dim=1, keepdim=True) x = x - mu if self.return_log_energy and self.raw_energy: # Compute the log energy of each frame x_strided = _get_strided_batch( x, self._length, self._shift, self.snip_edges, center=self.center ) log_energy = _get_log_energy(x_strided, self.energy_floor) # size (m) if self.preemph_coeff != 0.0: x_offset = torch.nn.functional.pad( x.unsqueeze(1), (1, 0), mode="replicate" ).squeeze(1) x = x - self.preemph_coeff * x_offset[:, :-1] x_strided = _get_strided_batch( x, self._length, self._shift, self.snip_edges, center=self.center ) # Apply window_function to each frame x_strided = x_strided * self._window if self.return_log_energy and not self.raw_energy: signal_log_energy = _get_log_energy( strided_input, self.energy_floor ) # size (batch, m) # Pad columns with zero until we reach size (batch, num_frames, pad_length) if self.pad_length != self._length: pad = self.pad_length - self._length x_strided = torch.nn.functional.pad( x_strided.unsqueeze(1), (0, pad), mode="constant", value=0 ).squeeze(1) if self.return_log_energy: return x_strided, log_energy return x_strided
[docs]class Wav2FFT(nn.Module):
[docs] def __init__( self, fs=16000, frame_length=25, frame_shift=10, fft_length=512, remove_dc_offset=True, preemph_coeff=0.97, window_type="povey", dither=1, snip_edges=True, center=False, energy_floor=0, raw_energy=True, use_energy=True, ): super().__init__() N = int(math.floor(frame_length * fs / 1000)) if N > fft_length: k = math.ceil(math.log(N) / math.log(2)) self.fft_length = int(2 ** k) self.wav2win = Wav2Win( fs, frame_length, frame_shift, pad_length=fft_length, remove_dc_offset=remove_dc_offset, preemph_coeff=preemph_coeff, window_type=window_type, dither=dither, snip_edges=snip_edges, center=center, energy_floor=0, raw_energy=raw_energy, return_log_energy=use_energy, ) self.fft_length = fft_length self.use_energy = use_energy
@property def fs(self): return self.wav2win.fs @property def frame_length(self): return self.wav2win.frame_length @property def frame_shift(self): return self.wav2win.frame_shift @property def remove_dc_offset(self): return self.wav2win.remove_dc_offset @property def preemph_coeff(self): return self.wav2win.preemph_coeff @property def window_type(self): return self.wav2win.window_type @property def dither(self): return self.wav2win.dither
[docs] def forward(self, x): x_strided = self.wav2win(x) if self.use_energy: x_strided, log_e = x_strided # X = torch.rfft(x_strided, 1, normalized=False, onesided=True) X = _rfft(x_strided) if self.use_energy: X[:, 0, :, 0] = log_e return X
[docs]class Wav2Spec(Wav2FFT):
[docs] def __init__( self, fs=16000, frame_length=25, frame_shift=10, fft_length=512, remove_dc_offset=True, preemph_coeff=0.97, window_type="povey", use_fft_mag=False, dither=1, snip_edges=True, center=False, energy_floor=0, raw_energy=True, use_energy=True, ): super().__init__( fs, frame_length, frame_shift, fft_length, remove_dc_offset=remove_dc_offset, preemph_coeff=preemph_coeff, window_type=window_type, dither=dither, snip_edges=snip_edges, center=center, energy_floor=energy_floor, raw_energy=raw_energy, use_energy=use_energy, ) self.use_fft_mag = use_fft_mag if use_fft_mag: self._to_spec = _spectrogram else: self._to_spec = _pow_spectrogram
[docs] def forward(self, x): x_strided = self.wav2win(x) if self.use_energy: x_strided, log_e = x_strided # X = torch.rfft(x_strided, 1, normalized=False, onesided=True) X = _rfft(x_strided) pow_spec = self._to_spec(X) # pow_spec = X.pow(2).sum(-1) # if self.use_fft_mag: # pow_spec = pow_spec.sqrt() if self.use_energy: pow_spec[:, 0] = log_e return pow_spec
[docs]class Wav2LogSpec(Wav2FFT):
[docs] def __init__( self, fs=16000, frame_length=25, frame_shift=10, fft_length=512, remove_dc_offset=True, preemph_coeff=0.97, window_type="povey", use_fft_mag=False, dither=1, snip_edges=True, center=False, energy_floor=0, raw_energy=True, use_energy=True, ): super().__init__( fs, frame_length, frame_shift, fft_length, remove_dc_offset=remove_dc_offset, preemph_coeff=preemph_coeff, window_type=window_type, dither=dither, snip_edges=snip_edges, center=center, energy_floor=energy_floor, raw_energy=raw_energy, use_energy=use_energy, ) self.use_fft_mag = use_fft_mag if use_fft_mag: self._to_spec = _spectrogram else: self._to_spec = _pow_spectrogram
[docs] def forward(self, x): x_strided = self.wav2win(x) if self.use_energy: x_strided, log_e = x_strided # X = torch.rfft(x_strided, 1, normalized=False, onesided=True) X = _rfft(x_strided) pow_spec = self._to_spec(X) # pow_spec = X.pow(2).sum(-1) # if self.use_fft_mag: # pow_spec = pow_spec.sqrt() pow_spec = (pow_spec + 1e-15).log() if self.use_energy: pow_spec[:, 0] = log_e return pow_spec
[docs]class Wav2LogFilterBank(Wav2FFT):
[docs] def __init__( self, fs=16000, frame_length=25, frame_shift=10, fft_length=512, remove_dc_offset=True, preemph_coeff=0.97, window_type="povey", use_fft_mag=False, dither=1, fb_type="mel_kaldi", low_freq=20, high_freq=0, num_filters=23, norm_filters=False, snip_edges=True, center=False, energy_floor=0, raw_energy=True, use_energy=True, ): super().__init__( fs, frame_length, frame_shift, fft_length, remove_dc_offset=remove_dc_offset, preemph_coeff=preemph_coeff, window_type=window_type, dither=dither, snip_edges=snip_edges, center=center, energy_floor=energy_floor, raw_energy=raw_energy, use_energy=use_energy, ) self.use_fft_mag = use_fft_mag self.fb_type = fb_type self.low_freq = low_freq self.high_freq = high_freq self.num_filters = num_filters self.norm_filters = norm_filters fb = FBF.create( fb_type, num_filters, self.fft_length, self.fs, low_freq, high_freq, norm_filters, ) self._fb = nn.Parameter( torch.tensor(fb, dtype=torch.get_default_dtype()), requires_grad=False ) if use_fft_mag: self._to_spec = _spectrogram else: self._to_spec = _pow_spectrogram
[docs] def forward(self, x): x_strided = self.wav2win(x) if self.use_energy: x_strided, log_e = x_strided # X = torch.rfft(x_strided, 1, normalized=False, onesided=True) X = _rfft(x_strided) # logging.info('X={} {}'.format(X, X.type())) # logging.info('X={}'.format(X.type())) pow_spec = self._to_spec(X) # pow_spec = X.pow(2).sum(-1) # # logging.info('p={} {} nan={}'.format(pow_spec, pow_spec.type(), torch.sum(torch.isnan(pow_spec)))) # # logging.info('p={}'.format(pow_spec.type())) # if self.use_fft_mag: # pow_spec = pow_spec.sqrt() with amp.autocast(enabled=False): pow_spec = torch.matmul(pow_spec.float(), self._fb.float()) # logging.info('fb={} {}'.format(pow_spec, pow_spec.type())) # logging.info('fb={}'.format(pow_spec.type())) pow_spec = (pow_spec + 1e-10).log() # logging.info('lfb={} {}'.format(pow_spec, pow_spec.type())) # logging.info('lfb={}'.format(pow_spec.type())) if self.use_energy: pow_spec = torch.cat((log_e.unsqueeze(-1), pow_spec), dim=-1) return pow_spec
[docs]class Wav2MFCC(Wav2FFT):
[docs] def __init__( self, fs=16000, frame_length=25, frame_shift=10, fft_length=512, remove_dc_offset=True, preemph_coeff=0.97, window_type="povey", use_fft_mag=False, dither=1, fb_type="mel_kaldi", low_freq=20, high_freq=0, num_filters=23, norm_filters=False, num_ceps=13, snip_edges=True, center=False, cepstral_lifter=22, energy_floor=0, raw_energy=True, use_energy=True, ): super().__init__( fs, frame_length, frame_shift, fft_length, remove_dc_offset=remove_dc_offset, preemph_coeff=preemph_coeff, window_type=window_type, dither=dither, snip_edges=snip_edges, center=center, energy_floor=energy_floor, raw_energy=raw_energy, use_energy=use_energy, ) self.use_fft_mag = use_fft_mag self.fb_type = fb_type self.low_freq = low_freq self.high_freq = high_freq self.num_filters = num_filters self.norm_filters = norm_filters self.num_ceps = num_ceps self.cepstral_lifter = cepstral_lifter fb = FBF.create( fb_type, num_filters, self.fft_length, self.fs, low_freq, high_freq, norm_filters, ) self._fb = nn.Parameter( torch.tensor(fb, dtype=torch.get_default_dtype()), requires_grad=False ) self._dct = nn.Parameter( self.make_dct_matrix(self.num_ceps, self.num_filters), requires_grad=False ) self._lifter = nn.Parameter( self.make_lifter(self.num_ceps, self.cepstral_lifter), requires_grad=False ) if use_fft_mag: self._to_spec = _spectrogram else: self._to_spec = _pow_spectrogram
[docs] @staticmethod def make_lifter(N, Q): """Makes the liftering function Args: N: Number of cepstral coefficients. Q: Liftering parameter Returns: Liftering vector. """ if Q == 0: return 1 return 1 + 0.5 * Q * torch.sin( math.pi * torch.arange(N, dtype=torch.get_default_dtype()) / Q )
[docs] @staticmethod def make_dct_matrix(num_ceps, num_filters): n = torch.arange(float(num_filters)).unsqueeze(1) k = torch.arange(float(num_ceps)) dct = torch.cos( math.pi / float(num_filters) * (n + 0.5) * k ) # size (n_mfcc, n_mels) dct[:, 0] *= 1.0 / math.sqrt(2.0) dct *= math.sqrt(2.0 / float(num_filters)) return dct
[docs] def forward(self, x): x_strided = self.wav2win(x) if self.use_energy: x_strided, log_e = x_strided # X = torch.rfft(x_strided, 1, normalized=False, onesided=True) X = _rfft(x_strided) pow_spec = self._to_spec(X) # pow_spec = X.pow(2).sum(-1) # if self.use_fft_mag: # pow_spec = pow_spec.sqrt() with amp.autocast(enabled=False): pow_spec = torch.matmul(pow_spec.float(), self._fb.float()) pow_spec = (pow_spec + 1e-10).log() mfcc = torch.matmul(pow_spec, self._dct) if self.cepstral_lifter > 0: mfcc *= self._lifter if self.use_energy: mfcc[:, 0] = log_e return mfcc
[docs]class Wav2KanBayashiLogFilterBank(Wav2LogFilterBank): """Class to replicate log-filter-banks used in Kan Bayashi's ParallelWaveGAN repository: https://github.com/kan-bayashi/ParallelWaveGAN """
[docs] def __init__( self, fs=16000, frame_length=64, frame_shift=16, fft_length=1024, remove_dc_offset=True, window_type="hanning", low_freq=80, high_freq=7600, num_filters=80, snip_edges=False, center=True, ): super().__init__( fs=fs, frame_length=frame_length, frame_shift=frame_shift, fft_length=fft_length, remove_dc_offset=remove_dc_offset, preemph_coeff=0, window_type=window_type, use_fft_mag=True, dither=1e-5, fb_type="mel_librosa", low_freq=low_freq, high_freq=high_freq, num_filters=num_filters, norm_filters=True, snip_edges=snip_edges, center=center, use_energy=False, ) # Kan Bayashi uses log10 instead of log self.scale = 1.0 / math.log(10)
[docs] def forward(self, x): return self.scale * super().forward(x)
[docs]class Spec2LogFilterBank:
[docs] def __init__( self, fs=16000, fft_length=512, fb_type="mel_kaldi", low_freq=20, high_freq=0, num_filters=23, norm_filters=False, ): super().__init__() self.fs = fs self.fft_length = fft_length self.fb_type = fb_type self.low_freq = low_freq self.high_freq = high_freq self.num_filters = num_filters self.norm_filters = norm_filters fb = FBF.create( fb_type, num_filters, self.fft_length, self.fs, low_freq, high_freq, norm_filters, ) self._fb = nn.Parameter( torch.tensor(fb, dtype=torch.get_default_dtype()), requires_grad=False )
[docs] def forward(self, x): with amp.autocast(enabled=False): pow_spec = torch.matmul(x.float(), self._fb.float()) pow_spec = (pow_spec + 1e-10).log() return pow_spec