Source code for hyperion.torch.data.embed_dataset

"""
 Copyright 2021 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

# import sys
# import os
import logging
import time

# import copy

import numpy as np
import pandas as pd

import torch

from ..torch_defs import floatstr_torch
from ...io import RandomAccessDataReaderFactory as RF
from ...utils.utt2info import Utt2Info

from torch.utils.data import Dataset


class EmbedDataset(Dataset):
    """Map-style dataset that serves embedding vectors and their class indices.

    Embeddings come from one of two sources:

    * an in-memory array passed as ``embeds``, or
    * a random-access reader created from ``rspecifier``, whose entries are
      addressed by the keys of ``key_file``. They are either read one by one
      in ``__getitem__`` or loaded all at once when ``preload_embeds`` is set.

    Class indices come from ``class_ids``, from the labels stored in
    ``key_file``, or from mapping those labels through ``class_file``.

    Attributes set by the constructor:
        u2c: the ``Utt2Info`` table, or None when only ``embeds`` was given.
        num_embeds: number of items in the dataset.
        r: the open reader, or None when embeddings are held in memory.
        embeds: in-memory embeddings, or None when read lazily.
        num_classes: number of classes.
        utt_idx2class: class index of every item.
        class2utt_idx: dict mapping class index -> array of item indices.
        class2num_utt: number of items per class.
        class_weights: normalized per-class weights as a torch tensor, or None.
    """

    def __init__(
        self,
        embeds=None,
        class_ids=None,
        class_weights=None,
        rspecifier=None,
        key_file=None,
        class_file=None,
        path_prefix=None,
        preload_embeds=False,
        return_class=True,
        is_val=False,
    ):
        """Build the dataset.

        Args:
            embeds: in-memory embeddings; must support ``len``, ``astype``
                and integer indexing. Required when ``rspecifier`` is None.
            class_ids: class index of every embedding; required when
                ``key_file`` is None.
            class_weights: optional per-class weights. They are normalized to
                sum to one; the caller's array is not modified.
            rspecifier: specifier handed to the random-access reader factory.
                Requires ``key_file``.
            key_file: ``Utt2Info`` object, or path of a space-separated file
                loaded with ``Utt2Info.load``.
            class_file: optional space-separated file without header; column 0
                holds the class names and an optional column 1 holds the class
                weights (which then replace ``class_weights``).
            path_prefix: forwarded to the reader factory.
            preload_embeds: when reading from ``rspecifier``, load every
                embedding at construction time instead of on access.
            return_class: if True ``__getitem__`` returns ``(x, class_idx)``,
                otherwise only ``x``.
            is_val: if True, the per-class "no samples" warning is skipped.
        """
        assert embeds is not None or rspecifier is not None
        assert rspecifier is None or key_file is not None
        assert class_ids is not None or key_file is not None

        self.preload_embeds = preload_embeds
        if key_file is not None:
            if isinstance(key_file, Utt2Info):
                self.u2c = key_file
            else:
                logging.info("loading utt2info file %s", key_file)
                self.u2c = Utt2Info.load(key_file, sep=" ")
            self.num_embeds = len(self.u2c)
        else:
            assert embeds is not None
            self.u2c = None
            self.num_embeds = len(embeds)

        # Define both storage attributes on every path so that later attribute
        # access never depends on which source was used.
        self.r = None
        self.embeds = None
        if embeds is None:
            logging.info("opening dataset %s", rspecifier)
            self.r = RF.create(rspecifier, path_prefix=path_prefix, scp_sep=" ")
            if self.preload_embeds:
                # The keys live on self.u2c (there is no local named u2c).
                self.embeds = self.r.load(self.u2c.key, squeeze=True).astype(
                    floatstr_torch(), copy=False
                )
                # Everything is in memory now; the reader is no longer needed.
                self.r = None
        else:
            # In-memory embeddings are by definition already loaded.
            self.preload_embeds = True
            self.embeds = embeds.astype(floatstr_torch(), copy=False)

        self.is_val = is_val
        self._prepare_class_info(class_file, class_ids, class_weights)
        self.return_class = return_class

        logging.info("dataset contains %d embeds", self.num_embeds)

    def __len__(self):
        """Return the number of embeddings in the dataset."""
        return self.num_embeds

    def _prepare_class_info(self, class_file, class_idx=None, class_weights=None):
        """Compute class indices, per-class item lists and class weights.

        Args:
            class_file: optional class list file (see ``__init__``). When
                given, the labels in ``self.u2c.info`` are mapped to indices
                by their row position in this file.
            class_idx: class index per item; only used when both
                ``class_file`` and ``self.u2c`` are None.
            class_weights: optional per-class weights; never modified in place.

        Sets ``num_classes``, ``utt_idx2class``, ``class2utt_idx``,
        ``class2num_utt`` and ``class_weights`` on the instance.
        """
        if class_file is None:
            if self.u2c is not None:
                # Classes are the sorted unique labels of the key file.
                _, class_idx = np.unique(self.u2c.info, return_inverse=True)
            self.num_classes = np.max(class_idx) + 1
        else:
            logging.info("reading class-file %s", class_file)
            class_info = pd.read_csv(class_file, header=None, sep=" ")
            class2idx = {str(k): i for i, k in enumerate(class_info[0])}
            self.num_classes = len(class2idx)
            class_idx = np.array([class2idx[k] for k in self.u2c.info], dtype=int)
            if class_info.shape[1] == 2:
                class_weights = np.array(class_info[1]).astype(
                    floatstr_torch(), copy=False
                )

        # Accept plain sequences: the element-wise comparison below needs an
        # ndarray (list == int would yield a single bool).
        class_idx = np.asarray(class_idx)

        if class_weights is not None:
            # Work on a private floating-point copy so that zeroing and
            # normalizing never touch the caller's array, and so that integer
            # weights can be divided.
            class_weights = np.asarray(class_weights)
            if np.issubdtype(class_weights.dtype, np.floating):
                class_weights = class_weights.copy()
            else:
                class_weights = class_weights.astype(float)

        class2utt_idx = {}
        class2num_utt = np.zeros((self.num_classes,), dtype=int)

        for k in range(self.num_classes):
            idx = (class_idx == k).nonzero()[0]
            class2utt_idx[k] = idx
            class2num_utt[k] = len(idx)
            if class2num_utt[k] == 0:
                if not self.is_val:
                    logging.warning("class %d doesn't have any samples", k)
                if class_weights is None:
                    class_weights = np.ones(
                        (self.num_classes,), dtype=floatstr_torch()
                    )
                # Empty classes get zero weight.
                class_weights[k] = 0

        count_empty = np.sum(class2num_utt == 0)
        if count_empty > 0:
            logging.warning("%d classes have 0 samples", count_empty)

        self.utt_idx2class = class_idx
        self.class2utt_idx = class2utt_idx
        self.class2num_utt = class2num_utt
        if class_weights is not None:
            class_weights = class_weights / np.sum(class_weights)
            class_weights = torch.Tensor(class_weights)
        self.class_weights = class_weights

    def __getitem__(self, index):
        """Return the embedding at ``index``, with its class index if enabled.

        Returns:
            ``x`` when ``return_class`` is False, else ``(x, class_idx)``.
        """
        if self.preload_embeds:
            x = self.embeds[index]
        else:
            # Lazy path: fetch a single entry from the reader by its key.
            key = self.u2c.key[index]
            x = self.r.read([key])[0].astype(floatstr_torch(), copy=False)

        if not self.return_class:
            return x

        class_idx = self.utt_idx2class[index]
        return x, class_idx