Source code for hyperion.utils.trial_key

"""
 Copyright 2018 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import os.path as path
import copy

import numpy as np
import h5py

from .list_utils import *
from .trial_ndx import TrialNdx


class TrialKey(object):
    """Contains the trial key for speaker recognition trials.
       Bosaris compatible Key.

    Attributes:
      model_set: List of model names.
      seg_set: List of test segment names.
      tar: Boolean matrix with target trials to True (num_models x num_segments).
      non: Boolean matrix with non-target trials to True (num_models x num_segments).
      model_cond: Conditions related to the model
                  (num_conditions x num_models).
      seg_cond: Conditions related to the test segment
                (num_conditions x num_segments).
      trial_cond: Conditions related to the combination of model and test
                  segment (num_conditions x num_models x num_segments).
      model_cond_name: String list with the names of the model conditions.
      seg_cond_name: String list with the names of the segment conditions.
      trial_cond_name: String list with the names of the trial conditions.
    """

    def __init__(
        self,
        model_set=None,
        seg_set=None,
        tar=None,
        non=None,
        model_cond=None,
        seg_cond=None,
        trial_cond=None,
        model_cond_name=None,
        seg_cond_name=None,
        trial_cond_name=None,
    ):
        """Stores the attributes and validates them when both name lists are given.

        Args:
          model_set: List of model names.
          seg_set: List of test segment names.
          tar: Boolean matrix of target trials (num_models x num_segments).
          non: Boolean matrix of non-target trials (num_models x num_segments).
          model_cond: Model conditions (num_conditions x num_models).
          seg_cond: Segment conditions (num_conditions x num_segments).
          trial_cond: Trial conditions
                      (num_conditions x num_models x num_segments).
          model_cond_name: Names of the model conditions.
          seg_cond_name: Names of the segment conditions.
          trial_cond_name: Names of the trial conditions.
        """
        self.model_set = model_set
        self.seg_set = seg_set
        self.tar = tar
        self.non = non
        self.model_cond = model_cond
        self.seg_cond = seg_cond
        self.trial_cond = trial_cond
        self.model_cond_name = model_cond_name
        self.seg_cond_name = seg_cond_name
        self.trial_cond_name = trial_cond_name
        # An "empty" key (no model/segment lists) is left unvalidated.
        if (model_set is not None) and (seg_set is not None):
            self.validate()

    @property
    def num_models(self):
        """Number of models in the key."""
        return len(self.model_set)

    @property
    def num_tests(self):
        """Number of test segments in the key."""
        return len(self.seg_set)

    def copy(self):
        """Makes a deep copy of the object."""
        return copy.deepcopy(self)

    def sort(self):
        """Sorts the object in place by model and test segment names."""
        self.model_set, m_idx = sort(self.model_set, return_index=True)
        self.seg_set, s_idx = sort(self.seg_set, return_index=True)
        ix = np.ix_(m_idx, s_idx)
        self.tar = self.tar[ix]
        self.non = self.non[ix]
        # Condition matrices keep the condition index on axis 0 (see
        # validate()), so the permutation is applied to the remaining axes.
        if self.model_cond is not None:
            self.model_cond = self.model_cond[:, m_idx]
        if self.seg_cond is not None:
            self.seg_cond = self.seg_cond[:, s_idx]
        if self.trial_cond is not None:
            self.trial_cond = self.trial_cond[(slice(None),) + ix]

    def save(self, file_path):
        """Saves object to txt/h5 file.

        The format is chosen from the file extension: ".h5"/".hdf5" are
        written as HDF5, anything else as text.

        Args:
          file_path: File to write the list.
        """
        _, file_ext = path.splitext(file_path)
        if file_ext in (".h5", ".hdf5"):
            self.save_h5(file_path)
        else:
            self.save_txt(file_path)

    def save_h5(self, file_path):
        """Saves object to h5 file.

        Args:
          file_path: File to write the list.
        """
        with h5py.File(file_path, "w") as f:
            model_set = self.model_set.astype("S")
            seg_set = self.seg_set.astype("S")
            f.create_dataset("ID/row_ids", data=model_set)
            f.create_dataset("ID/column_ids", data=seg_set)
            # Single int8 mask: +1 target, -1 non-target, 0 otherwise.
            # NOTE: a trial flagged in both tar and non also yields 0.
            trial_mask = self.tar.astype("int8") - self.non.astype("int8")
            f.create_dataset("trial_mask", data=trial_mask)
            if self.model_cond is not None:
                f.create_dataset("model_cond", data=self.model_cond.astype("uint8"))
            if self.seg_cond is not None:
                f.create_dataset("seg_cond", data=self.seg_cond.astype("uint8"))
            if self.trial_cond is not None:
                f.create_dataset("trial_cond", data=self.trial_cond.astype("uint8"))
            if self.model_cond_name is not None:
                model_cond_name = self.model_cond_name.astype("S")
                f.create_dataset("model_cond_name", data=model_cond_name)
            if self.seg_cond_name is not None:
                seg_cond_name = self.seg_cond_name.astype("S")
                f.create_dataset("seg_cond_name", data=seg_cond_name)
            if self.trial_cond_name is not None:
                trial_cond_name = self.trial_cond_name.astype("S")
                f.create_dataset("trial_cond_name", data=trial_cond_name)

    def save_txt(self, file_path):
        """Saves object to txt file.

        One line per trial, "<model> <segment> target|nontarget"; all target
        trials are written first. Conditions are not written.

        Args:
          file_path: File to write the list.
        """
        with open(file_path, "w") as f:
            for mask, label in ((self.tar, "target"), (self.non, "nontarget")):
                # Transposing makes the traversal segment-major, so the
                # first index is the segment and the second the model.
                seg_idx, mod_idx = np.nonzero(mask.T)
                for s, m in zip(seg_idx, mod_idx):
                    f.write("%s %s %s\n" % (self.model_set[m], self.seg_set[s], label))

    @classmethod
    def load(cls, file_path):
        """Loads object from txt/h5 file

        Args:
          file_path: File to read the list. ".h5"/".hdf5" are read as HDF5,
                     anything else as text.

        Returns:
          TrialKey object.
        """
        _, file_ext = path.splitext(file_path)
        if file_ext in (".h5", ".hdf5"):
            return cls.load_h5(file_path)
        return cls.load_txt(file_path)

    @classmethod
    def load_h5(cls, file_path):
        """Loads object from h5 file

        Args:
          file_path: File to read the list.

        Returns:
          TrialKey object.
        """
        with h5py.File(file_path, "r") as f:
            model_set = [t.decode("utf-8") for t in f["ID/row_ids"]]
            seg_set = [t.decode("utf-8") for t in f["ID/column_ids"]]

            # Positive mask entries are targets, negative ones non-targets.
            trial_mask = np.asarray(f["trial_mask"], dtype="int8")
            tar = (trial_mask > 0).astype("bool")
            non = (trial_mask < 0).astype("bool")

            model_cond = None
            seg_cond = None
            trial_cond = None
            model_cond_name = None
            seg_cond_name = None
            trial_cond_name = None
            # Condition datasets are optional in the file.
            if "model_cond" in f:
                model_cond = np.asarray(f["model_cond"], dtype="bool")
            if "seg_cond" in f:
                seg_cond = np.asarray(f["seg_cond"], dtype="bool")
            if "trial_cond" in f:
                trial_cond = np.asarray(f["trial_cond"], dtype="bool")
            if "model_cond_name" in f:
                model_cond_name = np.asarray(f["model_cond_name"], dtype="U")
            if "seg_cond_name" in f:
                seg_cond_name = np.asarray(f["seg_cond_name"], dtype="U")
            if "trial_cond_name" in f:
                trial_cond_name = np.asarray(f["trial_cond_name"], dtype="U")

        return cls(
            model_set,
            seg_set,
            tar,
            non,
            model_cond,
            seg_cond,
            trial_cond,
            model_cond_name,
            seg_cond_name,
            trial_cond_name,
        )

    @classmethod
    def load_txt(cls, file_path):
        """Loads object from txt file

        Each non-blank line is "<model> <segment> <label>"; a label equal to
        "target" marks a target trial, any other label a non-target trial.

        Args:
          file_path: File to read the list.

        Returns:
          TrialKey object.
        """
        with open(file_path, "r") as f:
            # Blank lines (e.g. a trailing newline) carry no trial.
            fields = [line.split() for line in f if line.strip()]
        models = [i[0] for i in fields]
        segments = [i[1] for i in fields]
        is_tar = [i[2] == "target" for i in fields]
        # The inverse indices map every line to its row/column in the
        # sorted unique model/segment sets.
        model_set, _, model_idx = np.unique(
            models, return_index=True, return_inverse=True
        )
        seg_set, _, seg_idx = np.unique(
            segments, return_index=True, return_inverse=True
        )
        tar = np.zeros((len(model_set), len(seg_set)), dtype="bool")
        non = np.zeros((len(model_set), len(seg_set)), dtype="bool")
        for m, s, t in zip(model_idx, seg_idx, is_tar):
            if t:
                tar[m, s] = True
            else:
                non[m, s] = True
        return cls(model_set, seg_set, tar, non)

    @staticmethod
    def _embed_in_grid(
        new_model_set, new_seg_set, model_set, seg_set, tar, non, conds, num_conds
    ):
        """Copies the matrices of one key into a larger (model, segment) grid.

        Args:
          new_model_set: Model names of the destination grid.
          new_seg_set: Segment names of the destination grid.
          model_set: Model names of the source key.
          seg_set: Segment names of the source key.
          tar: Source target matrix.
          non: Source non-target matrix.
          conds: Tuple (model_cond, seg_cond, trial_cond) of the source key.
          num_conds: Tuple with the number of model/segment/trial conditions
                     of the destination; None disables that condition type.

        Returns:
          Tuple (tar, non, model_cond, seg_cond, trial_cond) in the
          destination grid; disabled condition types are None.
        """
        shape = (len(new_model_set), len(new_seg_set))
        _, mi_a, mi_b = intersect(
            new_model_set, model_set, assume_unique=True, return_index=True
        )
        _, si_a, si_b = intersect(
            new_seg_set, seg_set, assume_unique=True, return_index=True
        )
        ix_a = np.ix_(mi_a, si_a)
        ix_b = np.ix_(mi_b, si_b)

        new_tar = np.zeros(shape, dtype="bool")
        new_tar[ix_a] = tar[ix_b]
        new_non = np.zeros(shape, dtype="bool")
        new_non[ix_a] = non[ix_b]

        model_cond, seg_cond, trial_cond = conds
        num_model_cond, num_seg_cond, num_trial_cond = num_conds
        new_model_cond = None
        new_seg_cond = None
        new_trial_cond = None
        if num_model_cond is not None:
            new_model_cond = np.zeros((num_model_cond, shape[0]), dtype="bool")
            new_model_cond[:, mi_a] = model_cond[:, mi_b]
        if num_seg_cond is not None:
            # Segment conditions live on the segment axis of the grid.
            new_seg_cond = np.zeros((num_seg_cond, shape[1]), dtype="bool")
            new_seg_cond[:, si_a] = seg_cond[:, si_b]
        if num_trial_cond is not None:
            new_trial_cond = np.zeros((num_trial_cond,) + shape, dtype="bool")
            # Keep the condition axis whole and apply the open mesh to the
            # (model, segment) axes.
            all_conds = (slice(None),)
            new_trial_cond[all_conds + ix_a] = trial_cond[all_conds + ix_b]
        return new_tar, new_non, new_model_cond, new_seg_cond, new_trial_cond

    @classmethod
    def merge(cls, key_list):
        """Merges several key objects.

        Model and segment sets are the union of those of all keys, and the
        trial/condition matrices are OR-ed together. Which condition types
        are merged, and the condition names, are taken from the first key.

        Args:
          key_list: List of TrialKey objects.

        Returns:
          Merged TrialKey object.
        """
        first = key_list[0]
        model_set = first.model_set
        seg_set = first.seg_set
        tar = first.tar
        non = first.non
        model_cond = first.model_cond
        seg_cond = first.seg_cond
        trial_cond = first.trial_cond

        # None marks a condition type that the first key does not carry.
        num_conds = tuple(
            None if c is None else c.shape[0]
            for c in (model_cond, seg_cond, trial_cond)
        )

        for key_i in key_list[1:]:
            new_model_set = np.union1d(model_set, key_i.model_set)
            new_seg_set = np.union1d(seg_set, key_i.seg_set)

            # Bring the accumulated key and the new key into the union grid.
            tar_1, non_1, model_cond_1, seg_cond_1, trial_cond_1 = cls._embed_in_grid(
                new_model_set,
                new_seg_set,
                model_set,
                seg_set,
                tar,
                non,
                (model_cond, seg_cond, trial_cond),
                num_conds,
            )
            tar_2, non_2, model_cond_2, seg_cond_2, trial_cond_2 = cls._embed_in_grid(
                new_model_set,
                new_seg_set,
                key_i.model_set,
                key_i.seg_set,
                key_i.tar,
                key_i.non,
                (key_i.model_cond, key_i.seg_cond, key_i.trial_cond),
                num_conds,
            )

            model_set = new_model_set
            seg_set = new_seg_set
            tar = np.logical_or(tar_1, tar_2)
            non = np.logical_or(non_1, non_2)
            if model_cond is not None:
                model_cond = np.logical_or(model_cond_1, model_cond_2)
            if seg_cond is not None:
                seg_cond = np.logical_or(seg_cond_1, seg_cond_2)
            if trial_cond is not None:
                trial_cond = np.logical_or(trial_cond_1, trial_cond_2)

        return cls(
            model_set,
            seg_set,
            tar,
            non,
            model_cond,
            seg_cond,
            trial_cond,
            first.model_cond_name,
            first.seg_cond_name,
            first.trial_cond_name,
        )

    def filter(self, model_set, seg_set, keep=True):
        """Removes elements from TrialKey object.

        Args:
          model_set: List of models to keep or remove.
          seg_set: List of test segments to keep or remove.
          keep: If True, we keep the elements in model_set/seg_set,
                if False, we remove the elements in model_set/seg_set.

        Returns:
          Filtered TrialKey object.
        """
        if not keep:
            model_set = np.setdiff1d(self.model_set, model_set)
            seg_set = np.setdiff1d(self.seg_set, seg_set)

        # Every requested name must exist in this key.
        f, mod_idx = ismember(model_set, self.model_set)
        assert np.all(f)
        f, seg_idx = ismember(seg_set, self.seg_set)
        assert np.all(f)

        model_set = self.model_set[mod_idx]
        seg_set = self.seg_set[seg_idx]
        ix = np.ix_(mod_idx, seg_idx)
        tar = self.tar[ix]
        non = self.non[ix]

        model_cond = None
        seg_cond = None
        trial_cond = None
        if self.model_cond is not None:
            model_cond = self.model_cond[:, mod_idx]
        if self.seg_cond is not None:
            seg_cond = self.seg_cond[:, seg_idx]
        if self.trial_cond is not None:
            trial_cond = self.trial_cond[(slice(None),) + ix]

        return TrialKey(
            model_set,
            seg_set,
            tar,
            non,
            model_cond,
            seg_cond,
            trial_cond,
            self.model_cond_name,
            self.seg_cond_name,
            self.trial_cond_name,
        )

    def split(self, model_idx, num_model_parts, seg_idx, num_seg_parts):
        """Splits the TrialKey into num_model_parts x num_seg_parts and returns part
           (model_idx, seg_idx).

        Args:
          model_idx: Model index of the part to return from 1 to num_model_parts.
          num_model_parts: Number of parts to split the model list.
          seg_idx: Segment index of the part to return from 1 to num_seg_parts.
          num_seg_parts: Number of parts to split the test segment list.

        Returns:
          Subpart of the TrialKey
        """
        model_set, model_idx1 = split_list(self.model_set, model_idx, num_model_parts)
        seg_set, seg_idx1 = split_list(self.seg_set, seg_idx, num_seg_parts)
        ix = np.ix_(model_idx1, seg_idx1)
        tar = self.tar[ix]
        non = self.non[ix]

        model_cond = None
        seg_cond = None
        trial_cond = None
        if self.model_cond is not None:
            model_cond = self.model_cond[:, model_idx1]
        if self.seg_cond is not None:
            seg_cond = self.seg_cond[:, seg_idx1]
        if self.trial_cond is not None:
            trial_cond = self.trial_cond[(slice(None),) + ix]

        return TrialKey(
            model_set,
            seg_set,
            tar,
            non,
            model_cond,
            seg_cond,
            trial_cond,
            self.model_cond_name,
            self.seg_cond_name,
            self.trial_cond_name,
        )

    def to_ndx(self):
        """Converts TrialKey object into TrialNdx object.

        Returns:
          TrialNdx object whose mask is True wherever the trial is either
          target or non-target.
        """
        mask = np.logical_or(self.tar, self.non)
        return TrialNdx(self.model_set, self.seg_set, mask)

    def validate(self):
        """Validates the attributes of the TrialKey object.

        Name lists are converted with list2ndarray, missing tar/non matrices
        are created empty, and shapes are checked with assertions.
        """
        self.model_set = list2ndarray(self.model_set)
        self.seg_set = list2ndarray(self.seg_set)

        shape = (len(self.model_set), len(self.seg_set))
        # Model and segment names must be unique.
        assert len(np.unique(self.model_set)) == shape[0]
        assert len(np.unique(self.seg_set)) == shape[1]

        if (self.tar is None) or (self.non is None):
            self.tar = np.zeros(shape, dtype="bool")
            self.non = np.zeros(shape, dtype="bool")
        else:
            assert self.tar.shape == shape
            assert self.non.shape == shape

        # Condition matrices carry the condition index on axis 0.
        if self.model_cond is not None:
            assert self.model_cond.shape[1] == shape[0]
        if self.seg_cond is not None:
            assert self.seg_cond.shape[1] == shape[1]
        if self.trial_cond is not None:
            assert self.trial_cond.shape[1:] == shape

        if self.model_cond_name is not None:
            self.model_cond_name = list2ndarray(self.model_cond_name)
        if self.seg_cond_name is not None:
            self.seg_cond_name = list2ndarray(self.seg_cond_name)
        if self.trial_cond_name is not None:
            self.trial_cond_name = list2ndarray(self.trial_cond_name)

    def __eq__(self, other):
        """Equal operator

        Two keys are equal when their model/segment sets, target/non-target
        matrices, conditions and condition names all match.
        """
        if self.model_set.shape != other.model_set.shape:
            return False
        if not np.all(self.model_set == other.model_set):
            return False
        if self.seg_set.shape != other.seg_set.shape:
            return False
        if not np.all(self.seg_set == other.seg_set):
            return False
        if not (np.all(self.tar == other.tar) and np.all(self.non == other.non)):
            return False

        optional_attrs = (
            "model_cond",
            "seg_cond",
            "trial_cond",
            "model_cond_name",
            "seg_cond_name",
            "trial_cond_name",
        )
        for name in optional_attrs:
            mine = getattr(self, name)
            theirs = getattr(other, name)
            # Optional attributes must be present in both keys or in neither.
            if (mine is None) != (theirs is None):
                return False
            if mine is not None and not np.all(mine == theirs):
                return False
        return True

    def __ne__(self, other):
        """Non-equal operator"""
        return not self.__eq__(other)

    def __cmp__(self, other):
        """Comparison operator

        Returns 0 when both keys are equal and 1 otherwise.
        """
        if self.__eq__(other):
            return 0
        return 1
def test(key_file="core-core_det5_key.h5"):
    """Self-test for TrialKey: copy, merge, filter, split, to_ndx and file I/O.

    Writes "test.h5" and "test.txt" in the current working directory.

    Args:
      key_file: Path to an existing key file used as reference data.
    """
    key1 = TrialKey.load(key_file)
    key1.sort()

    # A copy must be independent from the original.
    key2 = key1.copy()
    key2.model_set[0] = "m1"
    key2.tar[:] = 0
    assert np.any(key1.model_set != key2.model_set)
    assert np.any(key1.tar != key2.tar)

    # Merging two overlapping model ranges rebuilds the full key.
    key2 = TrialKey(
        key1.model_set[:10], key1.seg_set, key1.tar[:10, :], key1.non[:10, :]
    )
    key3 = TrialKey(
        key1.model_set[5:], key1.seg_set, key1.tar[5:, :], key1.non[5:, :]
    )
    key4 = TrialKey.merge([key2, key3])
    assert key1 == key4

    # Same with overlapping segment ranges.
    key2 = TrialKey(
        key1.model_set, key1.seg_set[:10], key1.tar[:, :10], key1.non[:, :10]
    )
    key3 = TrialKey(
        key1.model_set, key1.seg_set[5:], key1.tar[:, 5:], key1.non[:, 5:]
    )
    key4 = TrialKey.merge([key2, key3])
    assert key1 == key4

    # Filtering by the names of a sub-key must return that sub-key.
    key2 = TrialKey(
        key1.model_set[:5], key1.seg_set[:10], key1.tar[:5, :10], key1.non[:5, :10]
    )
    key3 = key1.filter(key2.model_set, key2.seg_set, keep=True)
    assert key2 == key3

    # Splitting into a grid of parts and merging them back is lossless.
    num_parts = 3
    key_list = []
    for i in range(num_parts):
        for j in range(num_parts):
            key_ij = key1.split(i + 1, num_parts, j + 1, num_parts)
            key_list.append(key_ij)
    key2 = TrialKey.merge(key_list)
    assert key1 == key2

    ndx1 = key1.to_ndx()
    ndx1.validate()

    # HDF5 round trip: compare against the key that was just loaded.
    file_h5 = "test.h5"
    key1.save(file_h5)
    key3 = TrialKey.load(file_h5)
    assert key1 == key3

    # Text round trip; the first row/column are filled so that every model
    # and segment appears in at least one written trial.
    file_txt = "test.txt"
    key3.tar[0, :] = True
    key3.non[:, 0] = True
    key3.save(file_txt)
    key2 = TrialKey.load(file_txt)
    assert key3 == key2