Source code for hyperion.utils.utt2info

"""
 Copyright 2018 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import os.path as path
import logging
from collections import OrderedDict
from copy import deepcopy

import numpy as np
import pandas as pd

from .list_utils import *


[docs]class Utt2Info(object): """Class to manipulate utt2spk, utt2lang, etc. files. Attributes: key: segment key name. info: key_to_index: Dictionary that returns the position of a key in the list. """
[docs] def __init__(self, utt_info): self.utt_info = utt_info self.validate() self.utt_info.index = self.utt_info.key self.key_to_index = None
[docs] def validate(self): """Validates the attributes of the Utt2Info object.""" assert "key" in self.utt_info.columns assert self.utt_info.shape[1] >= 2
# assert self.utt_info['key'].nunique() == self.utt_info.shape[0]
[docs] @classmethod def create(cls, key, info): key = np.asarray(key) info = np.asarray(info) if info.ndim == 2: data = np.hstack((key[:, None], info)) else: data = np.vstack((key, info)).T num_columns = data.shape[1] columns = ["key"] + [i for i in range(1, num_columns)] utt_info = pd.DataFrame(data, columns=columns) return cls(utt_info)
@property def num_info_fields(self): return self.utt_info.shape[1] - 1 @property def key(self): return np.asarray(self.utt_info["key"]) @property def info(self): if self.utt_info.shape[1] > 2: return np.asarray(self.utt_info.iloc[:, 1:]) else: return np.asarray(self.utt_info[1])
[docs] def copy(self): """Makes a copy of the object.""" return deepcopy(self)
[docs] def __len__(self): """Returns the number of elements in the list.""" return len(self.utt_info)
[docs] def len(self): """Returns the number of elements in the list.""" return len(self.utt_info)
[docs] def _create_dict(self): """Creates dictionary that returns the position of a segment in the list. """ self.key_to_index = OrderedDict( (k, i) for i, k in enumerate(self.utt_info.index) )
[docs] def get_index(self, key): """Returns the position of key in the list.""" if self.key_to_index is None: self._create_dict() return self.key_to_index[key]
[docs] def __contains__(self, key): """Returns True if the list contains the key""" return key in self.utt_info.index
[docs] def __getitem__(self, key): """It allows to acces the data in the list by key or index like in a ditionary, e.g.: If input is a string key: utt2spk = Utt2Info(info) spk_id = utt2spk['data1'] If input is an index: key, spk_id = utt2spk[0] Args: key: String key or integer index. Returns: If key is a string: info corresponding to key If key is the index in the key list: key, info given index """ if isinstance(key, str): row = np.array(self.utt_info.loc[key])[1:] if len(row) == 1: return row[0] else: return row else: row = np.array(self.utt_info.iloc[key]) if len(row) == 2: return row[0], row[1] else: return row[0], row[1:]
[docs] def sort(self, field=0): """Sorts the list by key""" if field == 0: self.utt_info.sort_index(ascending=True, inplace=True) else: idx = np.argsort(self.utt_info[field]) self.utt_info = self.utt_info.iloc[idx] self.key_to_index = None
[docs] def save(self, file_path, sep=" "): """Saves uttinfo to text file. Args: file_path: File to write the list. sep: Separator between the key and file_path in the text file. """ self.utt_info.to_csv(file_path, sep=sep, header=False, index=False)
[docs] @classmethod def load(cls, file_path, sep=" ", dtype={0: np.str, 1: np.str}): """Loads utt2info list from text file. Args: file_path: File to read the list. sep: Separator between the key and file_path in the text file. dtype: Dictionary with the dtypes of each column. Returns: Utt2Info object """ df = pd.read_csv(file_path, sep=sep, header=None, dtype=dtype) df = df.rename(index=str, columns={0: "key"}) return cls(df)
[docs] def split(self, idx, num_parts, group_by_field=0): """Splits SCPList into num_parts and return part idx. Args: idx: Part to return from 1 to num_parts. num_parts: Number of parts to split the list. group_by_field: All the lines with the same value in column groub_by_field go to the same part Returns: Sub Utt2Info object """ if group_by_field == 0: key, idx1 = split_list(self.utt_info["key"], idx, num_parts) else: key, idx1 = split_list_group_by_key( self.utt_info[group_by_field], idx, num_parts ) utt_info = self.utt_info.iloc[idx1] return Utt2Info(utt_info)
[docs] @classmethod def merge(cls, info_lists): """Merges several Utt2Info tables. Args: info_lists: List of Utt2Info Returns: Utt2Info object concatenation the info_lists. """ df_list = [u2i.utt_info for u2i in info_lists] utt_info = pd.concat(df_list) return cls(utt_info)
[docs] def filter(self, filter_key, keep=True): """Removes elements from Utt2Info object by key Args: filter_key: List with the keys of the elements to keep or remove. keep: If True, we keep the elements in filter_key; if False, we remove the elements in filter_key; Returns: Utt2Info object. """ if not keep: filter_key = np.setdiff1d(self.utt_info["key"], filter_key) utt_info = self.utt_info.loc[filter_key] return Utt2Info(utt_info)
[docs] def filter_info(self, filter_key, field=1, keep=True): """Removes elements of Utt2Info by info value Args: filter_key: List with the file_path of the elements to keep or remove. field: Field number corresponding to the info to filter keep: If True, we keep the elements in filter_key; if False, we remove the elements in filter_key; Returns: Utt2Info object. """ if not keep: filter_key = np.setdiff1d(self.utt_info[field], filter_key) f, _ = ismember(filter_key, self.utt_info[field]) if not np.all(f): for k in filter_key[f == False]: logging.error("info %s not found in field %d" % (k, field)) raise Exception("not all keys were found in field %d" % (field)) f, _ = ismember(self.utt_info[field], filter_key) utt_info = self.utt_info.iloc[f] return Utt2Info(utt_info)
[docs] def filter_index(self, index, keep=True): """Removes elements of Utt2Info by index Args: filter_key: List with the index of the elements to keep or remove. keep: If True, we keep the elements in filter_key; if False, we remove the elements in filter_key; Returns: Utt2Info object. """ if not keep: index = np.setdiff1d(np.arange(len(self.key), dtype=np.int64), index) utt_info = self.utt_info.iloc[index] return Utt2Info(utt_info)
[docs] def shuffle(self, seed=1024, rng=None): """Shuffles the elements of the list. Args: seed: Seed for random number generator. rng: numpy random number generator object. Returns: Index used to shuffle the list. """ if rng is None: rng = np.random.RandomState(seed=seed) index = np.arange(len(self.key)) rng.shuffle(index) self.utt_info = self.utt_info.iloc[index] self.key_to_index = None return index
[docs] def __eq__(self, other): """Equal operator""" if self.utt_info.shape[0] == 0 and other.utt_info.shape[0] == 0: return True eq = self.utt_info.equals(other.utt_info) return eq
[docs] def __ne__(self, other): """Non-equal operator""" return not self.__eq__(other)
[docs] def __cmp__(self, other): """Comparison operator""" if self.__eq__(other): return 0 return 1