# Source code for hyperion.torch.data.feat_seq_dataset

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import sys
import os
import logging
from jsonargparse import ArgumentParser, ActionParser
import time
import copy
import threading

import numpy as np
import pandas as pd

import torch

from ..torch_defs import floatstr_torch
from ...io import RandomAccessDataReaderFactory as RF
from ...utils.utt2info import Utt2Info

from torch.utils.data import Dataset


class FeatSeqDataset(Dataset):
    """Map-style dataset serving feature sequences, or random chunks of them.

    Features are read through a random-access reader created from
    ``rspecifier``; the list of sequences and their class labels come from an
    utt2info file (``key_file``).  Sequences shorter than ``min_chunk_length``
    are dropped at construction time.

    Items are either full sequences (``return_fullseqs=True``) or random
    chunks of ``chunk_length`` consecutive frames.  ``__getitem__`` accepts a
    plain index, or an ``(index, chunk_length)`` pair when the chunk length is
    chosen by the caller.
    """

    def __init__(
        self,
        rspecifier,
        key_file,
        class_file=None,
        num_frames_file=None,
        path_prefix=None,
        min_chunk_length=1,
        max_chunk_length=None,
        return_fullseqs=False,
        return_class=True,
        transpose_input=True,
        is_val=False,
    ):
        """Open the feature reader and index the sequences.

        Args:
            rspecifier: read specifier handed to the random-access reader
                factory.
            key_file: utt2info file (space separated); its info field is used
                as the class label of each sequence.
            class_file: optional space-separated file whose first column is
                the ordered list of classes and whose optional second column
                holds class weights.  If None, the classes are the unique
                labels found in ``key_file``.
            num_frames_file: optional space-separated file with the sequence
                key in column 0 and its number of frames in column 1.  If
                None, lengths are obtained from the reader.
            path_prefix: path prefix forwarded to the reader factory.
            min_chunk_length: minimum chunk length; shorter sequences are
                pruned.  None is treated as 1.
            max_chunk_length: maximum chunk length; None means "same as
                ``min_chunk_length``".
            return_fullseqs: return whole sequences instead of chunks.
            return_class: return ``(x, class_idx)`` instead of only ``x``.
            transpose_input: transpose the feature matrix before returning it.
            is_val: when True, per-class "no samples" warnings are suppressed.
        """
        # The parser built by add_class_args defaults --min-chunk-length to
        # None; fall back to the signature default so the comparisons below
        # never see None.
        if min_chunk_length is None:
            min_chunk_length = 1
        # Resolve the default *before* it is compared with the sequence
        # lengths: `seq_lengths < None` raises TypeError.
        if max_chunk_length is None:
            max_chunk_length = min_chunk_length

        logging.info("opening dataset %s", rspecifier)
        self.r = RF.create(rspecifier, path_prefix=path_prefix, scp_sep=" ")
        logging.info("loading utt2info file %s", key_file)
        self.u2c = Utt2Info.load(key_file, sep=" ")
        logging.info("dataset contains %d seqs", self.num_seqs)

        # _prepare_class_info reads is_val, so it must be set first.
        self.is_val = is_val

        self._seq_lengths = None
        if num_frames_file is not None:
            self._read_num_frames_file(num_frames_file)
        self._prune_short_seqs(min_chunk_length)

        # _prepare_class_info also reads short_seq_exist.
        self.short_seq_exist = self._seq_shorter_than_max_length_exists(
            max_chunk_length
        )
        self._prepare_class_info(class_file)

        self._min_chunk_length = min_chunk_length
        self._max_chunk_length = max_chunk_length
        self.return_fullseqs = return_fullseqs
        self.return_class = return_class
        self.transpose_input = transpose_input

    def _read_num_frames_file(self, file_path):
        """Load per-sequence lengths from a "key num_frames" file.

        The lengths are stored in the same order as the keys of ``self.u2c``.
        """
        logging.info("reading num_frames file %s", file_path)
        nf_df = pd.read_csv(file_path, header=None, sep=" ")
        nf_df.index = nf_df[0]
        self._seq_lengths = nf_df.loc[self.u2c.key, 1].values

    @property
    def num_seqs(self):
        """Number of sequences currently in the dataset."""
        return len(self.u2c)

    def __len__(self):
        return self.num_seqs

    @property
    def seq_lengths(self):
        """Length of every sequence; read from the reader on first use."""
        if self._seq_lengths is None:
            self._seq_lengths = self.r.read_num_rows(self.u2c.key)
        return self._seq_lengths

    @property
    def total_length(self):
        """Sum of all sequence lengths."""
        return np.sum(self.seq_lengths)

    @property
    def min_chunk_length(self):
        """Minimum chunk length (shortest sequence when returning full seqs)."""
        if self.return_fullseqs:
            self._min_chunk_length = np.min(self.seq_lengths)
        return self._min_chunk_length

    @property
    def max_chunk_length(self):
        """Maximum chunk length (longest sequence if it was never set)."""
        if self._max_chunk_length is None:
            self._max_chunk_length = np.max(self.seq_lengths)
        return self._max_chunk_length

    @property
    def min_seq_length(self):
        """Length of the shortest sequence."""
        return np.min(self.seq_lengths)

    @property
    def max_seq_length(self):
        """Length of the longest sequence."""
        return np.max(self.seq_lengths)

    def _prune_short_seqs(self, min_length):
        """Drop every sequence with fewer than ``min_length`` frames."""
        logging.info("pruning short seqs")
        keep_idx = self.seq_lengths >= min_length
        self.u2c = self.u2c.filter_index(keep_idx)
        self._seq_lengths = self.seq_lengths[keep_idx]
        logging.info(
            "pruned seqs with min_length < %d, keep %d/%d seqs",
            min_length,
            self.num_seqs,
            len(keep_idx),
        )

    def _prepare_class_info(self, class_file):
        """Build the class <-> utterance index tables and the class weights.

        Sets ``num_classes``, ``utt_idx2class``, ``class2utt_idx``,
        ``class2num_utt`` and ``class_weights``; when short sequences exist
        it also sets ``class2max_length``.
        """
        class_weights = None
        if class_file is None:
            classes, class_idx = np.unique(self.u2c.info, return_inverse=True)
            class2idx = {k: i for i, k in enumerate(classes)}
        else:
            logging.info("reading class-file %s", class_file)
            class_info = pd.read_csv(class_file, header=None, sep=" ")
            class2idx = {str(k): i for i, k in enumerate(class_info[0])}
            class_idx = np.array([class2idx[k] for k in self.u2c.info], dtype=int)
            # An optional second column carries the class weights.
            if class_info.shape[1] == 2:
                class_weights = np.array(class_info[1]).astype(
                    floatstr_torch(), copy=False
                )

        self.num_classes = len(class2idx)

        class2utt_idx = {}
        class2num_utt = np.zeros((self.num_classes,), dtype=int)
        for k in range(self.num_classes):
            idx = (class_idx == k).nonzero()[0]
            class2utt_idx[k] = idx
            class2num_utt[k] = len(idx)
            if class2num_utt[k] == 0:
                if not self.is_val:
                    logging.warning("class %d doesn't have any samples", k)
                # Classes without samples get weight 0.
                if class_weights is None:
                    class_weights = np.ones(
                        (self.num_classes,), dtype=floatstr_torch()
                    )
                class_weights[k] = 0

        count_empty = np.sum(class2num_utt == 0)
        if count_empty > 0:
            logging.warning("%d classes have 0 samples", count_empty)

        self.utt_idx2class = class_idx
        self.class2utt_idx = class2utt_idx
        self.class2num_utt = class2num_utt
        if class_weights is not None:
            class_weights /= np.sum(class_weights)
            class_weights = torch.Tensor(class_weights)
        self.class_weights = class_weights

        if self.short_seq_exist:
            # If there are seqs shorter than max_chunk_length we need some
            # extra variables: class_weights, so that classes whose utts are
            # all shorter than the batch chunk length can be set to 0 ...
            if self.class_weights is None:
                self.class_weights = torch.ones((self.num_classes,))

            # ... and the max length of the utterances of each class.
            class2max_length = torch.zeros((self.num_classes,), dtype=torch.int)
            for c in range(self.num_classes):
                if class2num_utt[c] > 0:
                    class2max_length[c] = int(
                        np.max(self.seq_lengths[self.class2utt_idx[c]])
                    )

            self.class2max_length = class2max_length

    def _seq_shorter_than_max_length_exists(self, max_length):
        """Return whether any sequence has fewer than ``max_length`` frames."""
        return np.any(self.seq_lengths < max_length)

    @property
    def var_chunk_length(self):
        """True when chunk lengths vary (min_chunk_length < max_chunk_length)."""
        return self.min_chunk_length < self.max_chunk_length

    def get_random_chunk_length(self):
        """Draw a chunk length uniformly in [min_chunk_length, max_chunk_length].

        Returns ``max_chunk_length`` directly when the length is fixed.
        """
        if self.var_chunk_length:
            return torch.randint(
                low=self.min_chunk_length, high=self.max_chunk_length + 1, size=(1,)
            ).item()

        return self.max_chunk_length

    def __getitem__(self, index):
        """Return a full sequence or a random chunk of it.

        Args:
            index: sequence index, or an ``(index, chunk_length)`` pair when
                chunks are returned.
        """
        if self.return_fullseqs:
            return self._get_fullseq(index)
        return self._get_random_chunk(index)

    def _get_fullseq(self, index):
        """Read the whole sequence at ``index``.

        Returns ``x``, or ``(x, class_idx)`` when ``return_class`` is set.
        """
        key = self.u2c.key[index]
        x = self.r.read([key])[0].astype(floatstr_torch(), copy=False)
        if self.transpose_input:
            x = x.T

        if not self.return_class:
            return x

        class_idx = self.utt_idx2class[index]
        return x, class_idx

    def _split_index(self, index):
        """Split an item index into ``(index, chunk_length)``.

        A length-2 ``index`` is unpacked as ``(index, chunk_length)``; any
        other value, including a plain integer (which has no ``len``), is
        paired with ``max_chunk_length``.
        """
        try:
            is_pair = len(index) == 2
        except TypeError:
            # Unsized index, e.g. a plain int.
            is_pair = False

        if is_pair:
            index, chunk_length = index
            return index, chunk_length
        return index, self.max_chunk_length

    def _get_random_chunk(self, index):
        """Read a random chunk of the sequence selected by ``index``.

        The first frame is drawn uniformly so the chunk fits in the sequence.
        Returns ``x``, or ``(x, class_idx)`` when ``return_class`` is set.
        """
        index, chunk_length = self._split_index(index)

        key = self.u2c.key[index]
        full_seq_length = int(self.seq_lengths[index])
        assert (
            chunk_length <= full_seq_length
        ), "chunk_length(%d) <= full_seq_length(%d)" % (chunk_length, full_seq_length)

        first_frame = torch.randint(
            low=0, high=full_seq_length - chunk_length + 1, size=(1,)
        ).item()

        x = self.r.read([key], row_offset=first_frame, num_rows=chunk_length)[0].astype(
            floatstr_torch(), copy=False
        )
        if self.transpose_input:
            x = x.T

        if not self.return_class:
            return x

        class_idx = self.utt_idx2class[index]
        return x, class_idx

    @staticmethod
    def filter_args(**kwargs):
        """Keep only the keyword arguments that belong to this dataset.

        Returns a dict with the entries of ``kwargs`` whose key is one of the
        dataset options.
        """
        # NOTE(review): "part_idx" and "num_parts" are kept for backward
        # compatibility, but __init__ does not accept them -- confirm whether
        # any caller still relies on them.
        valid_args = (
            "path_prefix",
            "class_file",
            "num_frames_file",
            "min_chunk_length",
            "max_chunk_length",
            "return_fullseqs",
            "part_idx",
            "num_parts",
        )
        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Add the dataset options to an argument parser.

        Args:
            parser: parser that receives the options.
            prefix: if given, the options are added to a new sub-parser that
                is attached to ``parser`` under ``--<prefix>``.
        """
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        parser.add_argument(
            "--path-prefix", default="", help=("path prefix for rspecifier scp file")
        )
        parser.add_argument(
            "--class-file",
            default=None,
            help=("ordered list of classes keys, it can contain class weights"),
        )
        parser.add_argument(
            "--num-frames-file",
            default=None,
            help=(
                "utt to num_frames file, if None it reads from the dataset "
                "but it is slow"
            ),
        )
        parser.add_argument(
            "--min-chunk-length",
            type=int,
            default=None,
            help=("minimum length of sequence chunks"),
        )
        parser.add_argument(
            "--max-chunk-length",
            type=int,
            default=None,
            help=("maximum length of sequence chunks"),
        )
        parser.add_argument(
            "--return-fullseqs",
            default=False,
            action="store_true",
            help=("returns full sequences instead of chunks"),
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    # Alias of add_class_args.
    add_argparse_args = add_class_args