Source code for hyperion.torch.data.weighted_embed_sampler

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
# import os
import math
import logging

import numpy as np

import torch
from torch.utils.data import Sampler


class ClassWeightedEmbedSampler(Sampler):
    """Batch sampler that picks classes first, then one utterance per pick.

    For every batch it draws a set of class indices (uniformly, or with
    ``torch.multinomial`` over ``dataset.class_weights`` when those are set),
    repeats that set ``num_egs_per_class`` times, and for each resulting class
    entry picks one random utterance index of that class.  Each iteration
    step yields a list of utterance indices.

    The dataset object must provide ``__len__`` and the attributes
    ``class_weights`` (``None`` or a weight tensor), ``num_classes``,
    ``class2utt_idx`` (class index -> indexable of utterance indices) and
    ``class2num_utt`` (class index -> number of utterances).

    Args:
        dataset: dataset object exposing the attributes listed above.
        batch_size: number of utterance indices returned per batch.
        iters_per_epoch: multiplier applied to ``len(dataset)`` when
            computing how many batches make up one epoch.
        num_egs_per_class: how many utterances are drawn for every class
            selected in a batch.
    """

    def __init__(self, dataset, batch_size=1, iters_per_epoch=1, num_egs_per_class=1):
        super().__init__(None)
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_egs_per_class = num_egs_per_class
        self.batch = 0  # number of batches already produced in this epoch
        self.iters_per_epoch = iters_per_epoch

        self._len = int(math.ceil(self.iters_per_epoch * len(dataset) / batch_size))
        logging.info("num batches per epoch: %d", self._len)

        # Rounded up so that classes * egs_per_class covers batch_size;
        # the surplus is trimmed in _get_utt_idx.
        self._num_classes_per_batch = int(math.ceil(batch_size / num_egs_per_class))
        logging.info("num classes per batch: %d", self._num_classes_per_batch)

    def __len__(self):
        """Return the number of batches in one epoch."""
        return self._len

    def __iter__(self):
        """Reset the batch counter and return the sampler as its own iterator."""
        self.batch = 0
        return self

    def _remove_duplicate_idx(self, utt_idx):
        """Try to replace repeated utterance indices by fresh random ones.

        Duplicates are swapped for indices drawn uniformly from
        ``range(len(self.dataset))``.  At most 3 attempts are made, so the
        result is best-effort: it can still contain duplicates.  The number
        of indices is preserved.  When no duplicate is present the input is
        returned untouched; otherwise the unique (sorted) values come first,
        followed by the replacement draws.

        Args:
            utt_idx: 1-D tensor of utterance indices.

        Returns:
            Tensor with the same number of indices as ``utt_idx``.
        """
        utt_idx_uniq = torch.unique(utt_idx)
        c = 0
        # we make 3 tries to remove duplicate utt idx
        delta = len(utt_idx) - len(utt_idx_uniq)
        while delta > 0 and c < 3:
            extra_idx = torch.randint(low=0, high=len(self.dataset), size=(delta,))
            utt_idx = torch.cat((utt_idx_uniq, extra_idx))
            utt_idx_uniq = torch.unique(utt_idx)
            delta = len(utt_idx) - len(utt_idx_uniq)
            c += 1

        return utt_idx

    def _get_utt_idx(self):
        """Draw the utterance indices for one batch.

        Returns:
            Tensor holding at most ``batch_size`` utterance indices.
        """
        dataset = self.dataset
        num_classes_per_batch = self._num_classes_per_batch
        if dataset.class_weights is None:
            class_idx = torch.randint(
                low=0, high=dataset.num_classes, size=(num_classes_per_batch,)
            )
        else:
            class_idx = torch.multinomial(
                dataset.class_weights,
                num_samples=num_classes_per_batch,
                replacement=True,
            )

        if self.num_egs_per_class > 1:
            class_idx = class_idx.repeat(self.num_egs_per_class)

        # The ceil() used for the class count over-allocates when batch_size
        # is not a multiple of num_egs_per_class; trim so a batch never holds
        # more than batch_size items (a no-op in the divisible case).
        class_idx = class_idx[: int(self.batch_size)]

        utt_idx = torch.as_tensor(
            [
                dataset.class2utt_idx[c][
                    torch.randint(low=0, high=int(dataset.class2num_utt[c]), size=(1,))
                ]
                for c in class_idx.tolist()
            ]
        )

        utt_idx = self._remove_duplicate_idx(utt_idx)
        return utt_idx

    def __next__(self):
        """Return the next batch as a list of utterance indices.

        Raises:
            StopIteration: once ``len(self)`` batches have been produced.
        """
        if self.batch >= self._len:
            raise StopIteration

        utt_idx = self._get_utt_idx()
        if self.batch == 0:
            logging.info("batch 0 uttidx=%s", str(utt_idx[:10]))

        self.batch += 1
        return utt_idx.tolist()