Source code for hyperion.utils.trial_stats

"""
 Copyright 2018 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""


import os.path as path
import logging
import copy

import numpy as np
import pandas as pd

from ..hyp_defs import float_cpu
from .trial_ndx import TrialNdx
from .trial_key import TrialKey


class TrialStats(object):
    """Contains ancillary statistics from the trials, such as quality
    measures like SNR.

    This class was created to store statistics about adversarial attacks
    like SNR (signal-to-perturbation ratio), Linf, L2 norms of the
    perturbation, etc.

    Attributes:
      df_stats: pandas dataframe containing the stats, indexed by
        (modelid, segmentid).
    """

    def __init__(self, df_stats):
        """Builds the object from a dataframe of per-trial statistics.

        Args:
          df_stats: pandas dataframe containing the stats. It needs to
            include the modelid and segmentid columns. The dataframe is
            not modified; an indexed copy is stored in self.df_stats.
        """
        assert "modelid" in df_stats.columns, "df_stats needs a modelid column"
        assert (
            "segmentid" in df_stats.columns
        ), "df_stats needs a segmentid column"
        # set_index without inplace returns a new dataframe, so the object
        # owned by the caller keeps its original columns and index.
        self.df_stats = df_stats.set_index(["modelid", "segmentid"])
        # Cache of already computed stat matrices, keyed by stat name.
        self._stats_mats = dict()

    @classmethod
    def load(cls, file_path):
        """Loads stats file.

        Args:
          file_path: stats file in csv format.

        Returns:
          TrialStats object.
        """
        df = pd.read_csv(file_path)
        return cls(df)

    def save_h5(self, file_path):
        """Saves object to file.

        Despite the method name, the file is written in CSV format
        (including the modelid/segmentid index), so that it can be read
        back with load().

        Args:
          file_path: CSV format file.
        """
        self.df_stats.to_csv(file_path)

    def get_stats_mat(self, stat_name, ndx, raise_missing=True):
        """Returns a matrix of trial statistics sorted to match a given
        Ndx or Key object.

        The result is cached by stat_name only: a later call with the same
        stat_name returns the cached matrix regardless of the ndx passed.
        Call reset_stats_mats() before querying with a different ndx.

        Args:
          stat_name: name of the statistic (e.g. snr, linf), as given in
            the column name of the dataframe.
          ndx: Ndx or Key object.
          raise_missing: if True, raises an exception when the statistic
            of some trial cannot be retrieved; otherwise it logs a warning
            and leaves that entry at 0.

        Returns:
          Stat matrix (n_models x n_tests). Entries outside the trial mask
          are 0.

        Raises:
          Exception: if raise_missing is True and the statistic cannot be
            retrieved for some trial in the mask.
        """
        if stat_name in self._stats_mats:
            return self._stats_mats[stat_name]

        # A Key stores target and non-target masks separately; the trials
        # are their union. An Ndx provides the trial mask directly.
        if isinstance(ndx, TrialKey):
            trial_mask = np.logical_or(ndx.tar, ndx.non)
        else:
            trial_mask = ndx.trial_mask

        model_set = ndx.model_set
        seg_set = ndx.seg_set
        stats_mat = np.zeros(trial_mask.shape, dtype=float_cpu())
        for i in range(stats_mat.shape[0]):
            for j in range(stats_mat.shape[1]):
                if not trial_mask[i, j]:
                    continue

                try:
                    stats_mat[i, j] = self.df_stats.loc[model_set[i], seg_set[j]][
                        stat_name
                    ]
                except Exception as err:
                    # Exception instead of a bare except, so that
                    # KeyboardInterrupt/SystemExit are not swallowed while
                    # any lookup failure is still reported as missing.
                    err_str = "%s not found for %s-%s" % (
                        stat_name,
                        model_set[i],
                        seg_set[j],
                    )
                    if raise_missing:
                        raise Exception(err_str) from err

                    logging.warning(err_str)

        self._stats_mats[stat_name] = stats_mat
        return stats_mat

    def reset_stats_mats(self):
        """Removes all the cached stat matrices, so that the next call to
        get_stats_mat() recomputes them."""
        self._stats_mats.clear()