"""
Copyright 2020 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
from jsonargparse import ArgumentParser, ActionParser
import time
import math
import numpy as np
import pandas as pd
import torch
from ..torch_defs import floatstr_torch
from ...io import RandomAccessAudioReader as AR
from ...utils.utt2info import Utt2Info
from ...augment import SpeechAugment
from torch.utils.data import Dataset
import torch.distributed as dist
class AudioDataset(Dataset):
    """Map-style dataset that returns audio waveforms, either as full
    utterances or as randomly located chunks of a given duration.

    Each item is a tuple ``(x,)``, ``(x, x_clean)``, ``(x, class_idx)`` or
    ``(x, x_clean, class_idx)`` depending on ``return_clean_aug_pair`` and
    ``return_class``.

    Args:
      audio_path: audio specifier forwarded to the random access audio reader.
      key_file: utt2info file, loaded with ``Utt2Info.load(key_file, sep=" ")``;
        its ``key`` field gives the utterance keys and its ``info`` field
        the class label of each utterance.
      class_file: optional space-separated file whose first column is the
        ordered list of class names and whose optional second column
        contains class weights. If None, classes are the sorted unique
        labels found in ``key_file``.
      time_durs_file: space-separated file with utterance key in the first
        column and utterance duration (secs) in the second one. Required.
      min_chunk_length: minimum chunk duration; utterances shorter than
        this value are discarded.
      max_chunk_length: maximum chunk duration; if None, it is set equal
        to ``min_chunk_length``.
      aug_cfg: configuration passed to ``SpeechAugment.create``; if None,
        no augmentation is applied.
      return_fullseqs: if True, items are full utterances instead of chunks.
      return_class: if True, the class index is appended to each item.
      return_clean_aug_pair: if True, the signal before augmentation is
        returned together with the augmented one.
      transpose_input: if True, a leading singleton axis is added to the
        returned waveforms (``x[None, :]``).
      wav_scale: scale forwarded to the audio reader.
      is_val: if True, the per-class "no samples" warning is not logged.
    """

    def __init__(
        self,
        audio_path,
        key_file,
        class_file=None,
        time_durs_file=None,
        min_chunk_length=1,
        max_chunk_length=None,
        aug_cfg=None,
        return_fullseqs=False,
        return_class=True,
        return_clean_aug_pair=False,
        transpose_input=False,
        wav_scale=2 ** 15 - 1,
        is_val=False,
    ):
        super().__init__()

        # Query the process group only when it really exists, instead of
        # hiding every possible error behind a bare ``except``.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1

        self.rank = rank
        self.world_size = world_size

        if rank == 0:
            logging.info("opening dataset %s", audio_path)
        self.r = AR(audio_path, wav_scale=wav_scale)

        if rank == 0:
            logging.info("loading utt2info file %s", key_file)
        self.u2c = Utt2Info.load(key_file, sep=" ")

        if rank == 0:
            logging.info("dataset contains %d seqs", self.num_seqs)

        self.is_val = is_val

        self._read_time_durs_file(time_durs_file)
        self._prune_short_seqs(min_chunk_length)

        # Resolve the default before it is compared against the durations.
        if max_chunk_length is None:
            max_chunk_length = min_chunk_length

        # _prepare_class_info reads self.short_seq_exist, so it must be set first.
        self.short_seq_exist = self._seq_shorter_than_max_length_exists(
            max_chunk_length
        )
        self._prepare_class_info(class_file)

        self._min_chunk_length = min_chunk_length
        self._max_chunk_length = max_chunk_length

        self.return_fullseqs = return_fullseqs
        self.return_class = return_class
        self.return_clean_aug_pair = return_clean_aug_pair
        self.transpose_input = transpose_input

        self.augmenter = None
        self.reverb_context = 0
        if aug_cfg is not None:
            # A different seed per rank so that processes do not draw the
            # same augmentation sequence.
            self.augmenter = SpeechAugment.create(
                aug_cfg, random_seed=112358 + 1000 * rank
            )
            self.reverb_context = self.augmenter.max_reverb_context

    def _read_time_durs_file(self, file_path):
        """Loads the duration of each utterance into ``self._seq_lengths``,
        ordered as the keys of ``self.u2c``.

        Args:
          file_path: path (or file-like object accepted by ``pandas.read_csv``)
            of a space-separated file with columns ``key duration``.

        Raises:
          ValueError: if ``file_path`` is None.
        """
        if file_path is None:
            raise ValueError(
                "time_durs_file is required to obtain the utterance durations"
            )

        if self.rank == 0:
            logging.info("reading time_durs file %s", file_path)

        nf_df = pd.read_csv(file_path, header=None, sep=" ")
        nf_df.index = nf_df[0]
        self._seq_lengths = nf_df.loc[self.u2c.key, 1].values

    @property
    def wav_scale(self):
        """Waveform scale used by the audio reader."""
        return self.r.wav_scale

    @property
    def num_seqs(self):
        """Number of utterances in the dataset."""
        return len(self.u2c)

    def __len__(self):
        return self.num_seqs

    @property
    def seq_lengths(self):
        """Array with the duration of each utterance."""
        return self._seq_lengths

    @property
    def total_length(self):
        """Sum of the durations of all the utterances."""
        return np.sum(self.seq_lengths)

    @property
    def min_chunk_length(self):
        """Minimum chunk duration; when full sequences are returned it is
        the duration of the shortest utterance."""
        if self.return_fullseqs:
            self._min_chunk_length = np.min(self.seq_lengths)
        return self._min_chunk_length

    @property
    def max_chunk_length(self):
        """Maximum chunk duration; falls back to the duration of the
        longest utterance when no value is stored."""
        if self._max_chunk_length is None:
            self._max_chunk_length = np.max(self.seq_lengths)
        return self._max_chunk_length

    @property
    def min_seq_length(self):
        """Duration of the shortest utterance."""
        return np.min(self.seq_lengths)

    @property
    def max_seq_length(self):
        """Duration of the longest utterance."""
        return np.max(self.seq_lengths)

    def _prune_short_seqs(self, min_length):
        """Discards the utterances whose duration is below ``min_length``."""
        if self.rank == 0:
            logging.info("pruning short seqs")

        keep_idx = self.seq_lengths >= min_length
        self.u2c = self.u2c.filter_index(keep_idx)
        self._seq_lengths = self.seq_lengths[keep_idx]

        if self.rank == 0:
            logging.info(
                "pruned seqs with min_length < %f, keep %d/%d seqs",
                min_length,
                self.num_seqs,
                len(keep_idx),
            )

    def _prepare_class_info(self, class_file):
        """Builds the mappings between utterances and classes.

        Sets ``num_classes``, ``utt_idx2class``, ``class2utt_idx``,
        ``class2num_utt`` and ``class_weights`` (None when no weights are
        needed). When ``self.short_seq_exist`` is True, it also sets
        ``class2max_length`` with the longest duration found in each class.

        Args:
          class_file: optional space-separated file with class names in the
            first column and, optionally, class weights in the second one.
        """
        class_weights = None
        if class_file is None:
            classes, class_idx = np.unique(self.u2c.info, return_inverse=True)
            class2idx = {k: i for i, k in enumerate(classes)}
        else:
            if self.rank == 0:
                logging.info("reading class-file %s", class_file)
            class_info = pd.read_csv(class_file, header=None, sep=" ")
            class2idx = {str(k): i for i, k in enumerate(class_info[0])}
            class_idx = np.array([class2idx[k] for k in self.u2c.info], dtype=int)
            if class_info.shape[1] == 2:
                class_weights = np.array(class_info[1]).astype(
                    floatstr_torch(), copy=False
                )

        self.num_classes = len(class2idx)

        class2utt_idx = {}
        class2num_utt = np.zeros((self.num_classes,), dtype=int)
        for k in range(self.num_classes):
            idx = (class_idx == k).nonzero()[0]
            class2utt_idx[k] = idx
            class2num_utt[k] = len(idx)
            if class2num_utt[k] == 0:
                if not self.is_val:
                    logging.warning("class %d doesn't have any samples", k)
                # Classes without utterances get zero weight.
                if class_weights is None:
                    class_weights = np.ones(
                        (self.num_classes,), dtype=floatstr_torch()
                    )
                class_weights[k] = 0

        count_empty = np.sum(class2num_utt == 0)
        if count_empty > 0:
            logging.warning("%d classes have 0 samples", count_empty)

        self.utt_idx2class = class_idx
        self.class2utt_idx = class2utt_idx
        self.class2num_utt = class2num_utt

        if class_weights is not None:
            class_weights /= np.sum(class_weights)
            class_weights = torch.Tensor(class_weights)
        self.class_weights = class_weights

        if self.short_seq_exist:
            # If some utterances are shorter than max_chunk_length, class
            # weights are needed to zero out classes whose utterances are all
            # shorter than the batch chunk length.
            if self.class_weights is None:
                self.class_weights = torch.ones((self.num_classes,))

            # Maximum duration among the utterances of each class.
            class2max_length = torch.zeros((self.num_classes,), dtype=torch.float)
            for c in range(self.num_classes):
                if class2num_utt[c] > 0:
                    class2max_length[c] = float(
                        np.max(self.seq_lengths[self.class2utt_idx[c]])
                    )
            self.class2max_length = class2max_length

    def _seq_shorter_than_max_length_exists(self, max_length):
        """Returns True if some utterance is shorter than ``max_length``.

        A ``max_length`` of None means that there is no limit to compare
        against, so the result is False.
        """
        if max_length is None:
            return False
        return np.any(self.seq_lengths < max_length)

    @property
    def var_chunk_length(self):
        """True if the chunk duration can vary between min and max."""
        return self.min_chunk_length < self.max_chunk_length

    def get_random_chunk_length(self):
        """Draws a chunk duration uniformly in
        ``[min_chunk_length, max_chunk_length)``; when both limits coincide
        ``max_chunk_length`` is returned."""
        if self.var_chunk_length:
            return (
                torch.rand(size=(1,)).item()
                * (self.max_chunk_length - self.min_chunk_length)
                + self.min_chunk_length
            )

        return self.max_chunk_length

    def __getitem__(self, index):
        """Returns the full utterance or a random chunk of it, depending
        on ``return_fullseqs``.

        Args:
          index: utterance index; for random chunks it can also be a pair
            ``(utterance_index, chunk_length)``.
        """
        if self.return_fullseqs:
            return self._get_fullseq(index)
        return self._get_random_chunk(index)

    def _get_fullseq(self, index):
        """Reads a full utterance and applies augmentation if configured.

        Returns:
          Tuple ``(x,)`` or ``(x, x_clean)``, with the class index appended
          when ``return_class`` is True.
        """
        key = self.u2c.key[index]
        x, fs = self.r.read([key])
        x = x[0].astype(floatstr_torch(), copy=False)
        x_clean = x
        if self.augmenter is not None:
            x, aug_info = self.augmenter(x)
            # Same cast that _get_random_chunk applies to augmented audio.
            x = x.astype(floatstr_torch(), copy=False)

        if self.transpose_input:
            x = x[None, :]
            if self.return_clean_aug_pair:
                x_clean = x_clean[None, :]

        if self.return_clean_aug_pair:
            r = x, x_clean
        else:
            # Without this branch ``r`` was unbound for the default config.
            r = (x,)

        if not self.return_class:
            return r

        class_idx = self.utt_idx2class[index]
        return r + (class_idx,)

    def _get_random_chunk(self, index):
        """Reads a chunk placed at a random position of the utterance.

        When augmentation is enabled, up to ``self.reverb_context`` of audio
        preceding the chunk is also read, augmented together with the chunk
        and removed afterwards.

        Args:
          index: utterance index, or pair ``(utterance_index, chunk_length)``.

        Returns:
          Tuple ``(x,)`` or ``(x, x_clean)``, with the class index appended
          when ``return_class`` is True.
        """
        # Plain integer indices have no len(); treat them as "no chunk length".
        try:
            has_chunk_length = len(index) == 2
        except TypeError:
            has_chunk_length = False

        if has_chunk_length:
            index, chunk_length = index
        else:
            chunk_length = self.max_chunk_length

        key = self.u2c.key[index]
        full_seq_length = self.seq_lengths[index]
        assert (
            chunk_length <= full_seq_length
        ), "chunk_length(%f) <= full_seq_length(%f)" % (chunk_length, full_seq_length)

        time_offset = torch.rand(size=(1,)).item() * (full_seq_length - chunk_length)
        # The context cannot start before the beginning of the file.
        reverb_context = min(self.reverb_context, time_offset)
        time_offset -= reverb_context
        read_chunk_length = chunk_length + reverb_context

        # NOTE: earlier revisions of this code retried failed reads (floored
        # offset, read until end-of-file, read full file) because some files
        # produced seek errors, apparently in the soundfile library.
        x, fs = self.r.read([key], time_offset=time_offset, time_durs=read_chunk_length)
        x = x[0]
        fs = fs[0]
        x_clean = x

        if self.augmenter is not None:
            chunk_length_samples = int(chunk_length * fs)
            end_idx = len(x)
            # Samples read before the chunk only to provide reverb context.
            reverb_context_samples = end_idx - chunk_length_samples
            assert reverb_context_samples >= 0, (
                "key={} time-offset={}, read-chunk={} "
                "read-x-samples={}, chunk_samples={}, reverb_context_samples={}"
            ).format(
                key,
                time_offset,
                read_chunk_length,
                end_idx,
                chunk_length_samples,
                reverb_context_samples,
            )
            x, aug_info = self.augmenter(x)
            x = x[reverb_context_samples:end_idx]
            if self.return_clean_aug_pair:
                x_clean = x_clean[reverb_context_samples:end_idx]

        if self.return_clean_aug_pair:
            # Cast the clean signal whether or not an augmenter is present.
            x_clean = x_clean.astype(floatstr_torch(), copy=False)

        if self.transpose_input:
            x = x[None, :]
            if self.return_clean_aug_pair:
                x_clean = x_clean[None, :]

        x = x.astype(floatstr_torch(), copy=False)
        if self.return_clean_aug_pair:
            r = x, x_clean
        else:
            r = (x,)

        if not self.return_class:
            return r

        class_idx = self.utt_idx2class[index]
        return r + (class_idx,)

    @staticmethod
    def filter_args(**kwargs):
        """Selects, from ``kwargs``, the options related to the dataset and
        to the audio reader.

        Returns:
          Dictionary with the selected options.
        """
        ar_args = AR.filter_args(**kwargs)
        valid_args = (
            "path_prefix",
            "class_file",
            "time_durs_file",
            "min_chunk_length",
            "max_chunk_length",
            "return_fullseqs",
            "part_idx",
            "num_parts",
        )
        args = dict((k, kwargs[k]) for k in valid_args if k in kwargs)
        args.update(ar_args)
        return args

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Adds the dataset options to an argument parser.

        Args:
          parser: parser where the options are added.
          prefix: if not None, the options are added to a new parser that
            is attached to ``parser`` under the option ``--<prefix>``.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        parser.add_argument(
            "--class-file",
            default=None,
            help=("ordered list of classes keys, it can contain class weights"),
        )

        parser.add_argument(
            "--time-durs-file", default=None, help=("utt to duration in secs file")
        )

        parser.add_argument(
            "--min-chunk-length",
            type=float,
            default=None,
            help=("minimum length of sequence chunks"),
        )
        parser.add_argument(
            "--max-chunk-length",
            type=float,
            default=None,
            help=("maximum length of sequence chunks"),
        )

        parser.add_argument(
            "--return-fullseqs",
            default=False,
            action="store_true",
            help=("returns full sequences instead of chunks"),
        )

        AR.add_class_args(parser)
        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    # Alias of add_class_args.
    add_argparse_args = add_class_args