Source code for hyperion.utils.trial_ndx

"""
 Copyright 2018 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import os.path as path
import copy

import numpy as np
import h5py

from .list_utils import *


[docs]class TrialNdx(object): """Contains the trial index to run speaker recognition trials. Bosaris compatible Ndx. Attributes: model_set: List of model names. seg_set: List of test segment names. trial_mask: Boolean matrix with the trials to execute to True (num_models x num_segments). """
[docs] def __init__(self, model_set=None, seg_set=None, trial_mask=None): self.model_set = model_set self.seg_set = seg_set self.trial_mask = trial_mask if (model_set is not None) and (seg_set is not None): self.validate()
@property def num_models(self): return len(self.model_set) @property def num_tests(self): return len(self.seg_set)
[docs] def copy(self): """Makes a copy of the object""" return copy.deepcopy(self)
[docs] def sort(self): """Sorts the object by model and test segment names.""" self.model_set, m_idx = sort(self.model_set, return_index=True) self.seg_set, s_idx = sort(self.seg_set, return_index=True) self.trial_mask = self.trial_mask[np.ix_(m_idx, s_idx)]
[docs] def save(self, file_path): """Saves object to txt/h5 file. Args: file_path: File to write the list. """ file_base, file_ext = path.splitext(file_path) if file_ext == ".h5" or file_ext == ".hdf5": self.save_h5(file_path) else: self.save_txt(file_path)
[docs] def save_h5(self, file_path): """Saves object to h5 file. Args: file_path: File to write the list. """ with h5py.File(file_path, "w") as f: model_set = self.model_set.astype("S") seg_set = self.seg_set.astype("S") f.create_dataset("ID/row_ids", data=model_set) f.create_dataset("ID/column_ids", data=seg_set) f.create_dataset("trial_mask", data=self.trial_mask.astype("uint8"))
# model_set = self.model_set.astype('S') # f.create_dataset('ID/row_ids', self.model_set.shape, dtype=model_set.dtype) # f['ID/row_ids'] = model_set # seg_set = self.seg_set.astype('S') # f.create_dataset('ID/column_ids', self.seg_set.shape, dtype=seg_set.dtype) # f['ID/column_ids'] = seg_set # f.create_dataset('trial_mask', self.trial_mask.shape, dtype='uint8') # f['trial_mask'] = self.trial_mask.astype('uint8')
[docs] def save_txt(self, file_path): """Saves object to txt file. Args: file_path: File to write the list. """ idx = (self.trial_mask.T == True).nonzero() with open(file_path, "w") as f: for item in zip(idx[0], idx[1]): f.write("%s %s\n" % (self.model_set[item[1]], self.seg_set[item[0]]))
[docs] @classmethod def load(cls, file_path): """Loads object from txt/h5 file Args: file_path: File to read the list. Returns: TrialNdx object. """ file_base, file_ext = path.splitext(file_path) if file_ext == ".h5" or file_ext == ".hdf5": return cls.load_h5(file_path) else: return cls.load_txt(file_path)
[docs] @classmethod def load_h5(cls, file_path): """Loads object from h5 file Args: file_path: File to read the list. Returns: TrialNdx object. """ with h5py.File(file_path, "r") as f: model_set = [t.decode("utf-8") for t in f["ID/row_ids"]] seg_set = [t.decode("utf-8") for t in f["ID/column_ids"]] trial_mask = np.asarray(f["trial_mask"], dtype="bool") return cls(model_set, seg_set, trial_mask)
[docs] @classmethod def load_txt(cls, file_path): """Loads object from txt file Args: file_path: File to read the list. Returns: TrialNdx object. """ with open(file_path, "r") as f: fields = [line.split() for line in f] models = [i[0] for i in fields] segments = [i[1] for i in fields] model_set, _, model_idx = np.unique( models, return_index=True, return_inverse=True ) seg_set, _, seg_idx = np.unique( segments, return_index=True, return_inverse=True ) trial_mask = np.zeros((len(model_set), len(seg_set)), dtype="bool") for item in zip(model_idx, seg_idx): trial_mask[item[0], item[1]] = True return cls(model_set, seg_set, trial_mask)
[docs] @classmethod def merge(cls, ndx_list): """Merges several index objects. Args: key_list: List of TrialNdx objects. Returns: Merged TrialNdx object. """ num_ndx = len(ndx_list) model_set = ndx_list[0].model_set seg_set = ndx_list[0].seg_set trial_mask = ndx_list[0].trial_mask for i in range(1, num_ndx): ndx_i = ndx_list[i] new_model_set = np.union1d(model_set, ndx_i.model_set) new_seg_set = np.union1d(seg_set, ndx_i.seg_set) trial_mask_1 = np.zeros( (len(new_model_set), len(new_seg_set)), dtype="bool" ) _, mi_a, mi_b = intersect( new_model_set, model_set, assume_unique=True, return_index=True ) _, si_a, si_b = intersect( new_seg_set, seg_set, assume_unique=True, return_index=True ) trial_mask_1[np.ix_(mi_a, si_a)] = trial_mask[np.ix_(mi_b, si_b)] trial_mask_2 = np.zeros( (len(new_model_set), len(new_seg_set)), dtype="bool" ) _, mi_a, mi_b = intersect( new_model_set, ndx_i.model_set, assume_unique=True, return_index=True ) _, si_a, si_b = intersect( new_seg_set, ndx_i.seg_set, assume_unique=True, return_index=True ) trial_mask_2[np.ix_(mi_a, si_a)] = ndx_i.trial_mask[np.ix_(mi_b, si_b)] model_set = new_model_set seg_set = new_seg_set trial_mask = np.logical_or(trial_mask_1, trial_mask_2) return cls(model_set, seg_set, trial_mask)
[docs] @staticmethod def parse_eval_set(ndx, enroll, test=None, eval_set="enroll-test"): """Prepares the data structures required for evaluation. Args: ndx: TrialNdx object cotaining the trials for the main evaluation. enroll: Utt2Info where key are file_ids and second column are model names test: Utt2Info of where key are test segments names. Needed in the cases enroll-coh and coh-coh. eval_test: Type of of evaluation enroll-test: main evaluation of enrollment vs test segments. enroll-coh: enrollment vs cohort segments. coh-test: cohort vs test segments. coh-coh: cohort vs cohort segments. Return: ndx: TrialNdx object enroll: SCPList """ if eval_set == "enroll-test": enroll = enroll.filter_info(ndx.model_set) if eval_set == "enroll-coh": ndx = TrialNdx(ndx.model_set, test.file_path) enroll = enroll.filter_info(ndx.model_set) if eval_set == "coh-test": ndx = TrialNdx(enroll.key, ndx.seg_set) if eval_set == "coh-coh": ndx = TrialNdx(enroll.key, test.file_path) return ndx, enroll
[docs] def filter(self, model_set, seg_set, keep=True): """Removes elements from TrialNdx object. Args: model_set: List of models to keep or remove. seg_set: List of test segments to keep or remove. keep: If True, we keep the elements in model_set/seg_set, if False, we remove the elements in model_set/seg_set. Returns: Filtered TrialNdx object. """ if not (keep): model_set = np.setdiff1d(self.model_set, model_set) seg_set = np.setdiff1d(self.seg_set, seg_set) f, mod_idx = ismember(model_set, self.model_set) assert np.all(f) f, seg_idx = ismember(seg_set, self.seg_set) assert np.all(f) model_set = self.model_set[mod_idx] set_set = self.seg_set[seg_idx] trial_mask = self.trial_mask[np.ix_(mod_idx, seg_idx)] return TrialNdx(model_set, seg_set, trial_mask)
[docs] def split(self, model_idx, num_model_parts, seg_idx, num_seg_parts): """Splits the TrialNdx into num_model_parts x num_seg_parts and returns part (model_idx, seg_idx). Args: model_idx: Model index of the part to return from 1 to num_model_parts. num_model_parts: Number of parts to split the model list. seg_idx: Segment index of the part to return from 1 to num_model_parts. num_seg_parts: Number of parts to split the test segment list. Returns: Subpart of the TrialNdx """ model_set, model_idx1 = split_list(self.model_set, model_idx, num_model_parts) seg_set, seg_idx1 = split_list(self.seg_set, seg_idx, num_seg_parts) trial_mask = self.trial_mask[np.ix_(model_idx1, seg_idx1)] return TrialNdx(model_set, seg_set, trial_mask)
[docs] def validate(self): """Validates the attributes of the TrialKey object.""" self.model_set = list2ndarray(self.model_set) self.seg_set = list2ndarray(self.seg_set) assert len(np.unique(self.model_set)) == len(self.model_set) assert len(np.unique(self.seg_set)) == len(self.seg_set) if self.trial_mask is None: self.trial_mask = np.ones( (len(self.model_set), len(self.seg_set)), dtype="bool" ) else: assert self.trial_mask.shape == (len(self.model_set), len(self.seg_set))
[docs] def apply_segmentation_to_test(self, segment_list): """Splits test segment into multiple sub-segments Useful to create ndx for spk diarization or tracking. Args: segment_list: ExtSegmentList object with mapping of file_id to ext_segment_id Returns: New TrialNdx object with segment_ids in test instead of file_id. """ new_segset = [] new_mask = [] for i in range(self.num_tests): file_id = self.seg_set[i] segment_ids = segment_list.ext_segment_ids_from_file(file_id) new_segset.append(segment_ids) new_mask.append( np.repeat(self.trial_mask[:, i, None], len(segment_ids), axis=1) ) new_segset = np.concatenate(tuple(new_segset)) new_mask = np.concatenate(tuple(new_mask), axis=-1) return TrialNdx(self.model_set, new_segset, new_mask)
[docs] def __eq__(self, other): """Equal operator""" eq = self.model_set.shape == other.model_set.shape eq = eq and np.all(self.model_set == other.model_set) eq = eq and (self.seg_set.shape == other.seg_set.shape) eq = eq and np.all(self.seg_set == other.seg_set) eq = eq and np.all(self.trial_mask == other.trial_mask) return eq
[docs] def __ne__(self, other): """Non-equal operator""" return not self.__eq__(other)
[docs] def __cmp__(self, other): """Comparison operator""" if self.__eq__(oher): return 0 return 1
[docs] def test(ndx_file="core-core_det5_ndx.h5"): ndx1 = TrialNdx.load(ndx_file) ndx1.sort() ndx2 = ndx1.copy() ndx2.model_set[0] = "m1" ndx2.trial_mask[:] = 0 assert np.any(ndx1.model_set != ndx2.model_set) assert np.any(ndx1.trial_mask != ndx2.trial_mask) ndx2 = TrialNdx(ndx1.model_set[:10], ndx1.seg_set, ndx1.trial_mask[:10, :]) ndx3 = TrialNdx(ndx1.model_set[5:], ndx1.seg_set, ndx1.trial_mask[5:, :]) ndx4 = TrialNdx.merge([ndx2, ndx3]) assert ndx1 == ndx4 ndx2 = TrialNdx(ndx1.model_set, ndx1.seg_set[:10], ndx1.trial_mask[:, :10]) ndx3 = TrialNdx(ndx1.model_set, ndx1.seg_set[5:], ndx1.trial_mask[:, 5:]) ndx4 = TrialNdx.merge([ndx2, ndx3]) assert ndx1 == ndx4 ndx2 = TrialNdx(ndx1.model_set[:5], ndx1.seg_set[:10], ndx1.trial_mask[:5, :10]) ndx3 = ndx1.filter(ndx2.model_set, ndx2.seg_set, keep=True) assert ndx2 == ndx3 num_parts = 3 ndx_list = [] for i in range(num_parts): for j in range(num_parts): ndx_ij = ndx1.split(i + 1, num_parts, j + 1, num_parts) ndx_list.append(ndx_ij) ndx2 = TrialNdx.merge(ndx_list) assert ndx1 == ndx2 file_h5 = "test.h5" ndx1.save(file_h5) ndx2 = TrialNdx.load(file_h5) assert ndx1 == ndx2 file_txt = "test.txt" ndx3.trial_mask[0, :] = True ndx3.trial_mask[:, 0] = True ndx3.save(file_txt) ndx2 = TrialNdx.load(file_txt) assert ndx3 == ndx2