"""
Copyright 2018 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import sys
import os
import argparse
import time
import copy
import numpy as np
from ..io import RandomAccessDataReaderFactory as DRF
from ..utils.utt2info import Utt2Info
from ..utils import TrialNdx, TrialKey
from ..transforms import TransformList
[docs]class TrialDataReader(object):
"""
Loads Ndx, enroll file and x-vectors to evaluate PLDA.
"""
[docs] def __init__(
self,
v_file,
ndx_file,
enroll_file,
test_file=None,
preproc=None,
model_part_idx=1,
num_model_parts=1,
seg_part_idx=1,
num_seg_parts=1,
eval_set="enroll-test",
tlist_sep=" ",
):
self.r = DRF.create(v_file)
self.preproc = preproc
enroll = Utt2Info.load(enroll_file, sep=tlist_sep)
test = None
if test_file is not None:
test = Utt2Info.load(test_file, sep=tlist_sep)
ndx = None
if ndx_file is not None:
try:
ndx = TrialNdx.load(ndx_file)
except:
ndx = TrialKey.load(ndx_file).to_ndx()
ndx, enroll = TrialNdx.parse_eval_set(ndx, enroll, test, eval_set)
if num_model_parts > 1 or num_seg_parts > 1:
ndx = ndx.split(
model_part_idx, num_model_parts, seg_part_idx, num_seg_parts
)
enroll = enroll.filter_info(ndx.model_set)
self.enroll = enroll
self.ndx = ndx
[docs] def read(self):
x_e = self.r.read(self.enroll.key, squeeze=True)
x_t = self.r.read(self.ndx.seg_set, squeeze=True)
if self.preproc is not None:
x_e = self.preproc.predict(x_e)
x_t = self.preproc.predict(x_t)
return x_e, x_t, self.enroll.info, self.ndx
[docs] @staticmethod
def filter_args(**kwargs):
valid_args = (
"tlist_sep",
"model_idx",
"num_model_parts",
"seg_idx",
"num_seg_parts",
"eval_set",
)
return dict((k, kwargs[k]) for k in valid_args if k in kwargs)
[docs] @staticmethod
def add_class_args(parser, prefix=None):
if prefix is None:
p1 = "--"
else:
p1 = "--" + prefix + "."
parser.add_argument(
p1 + "tlist-sep", default=" ", help=("trial lists field separator")
)
parser.add_argument(
p1 + "model-part-idx", default=1, type=int, help=("model part index")
)
parser.add_argument(
p1 + "num-model-parts",
default=1,
type=int,
help=(
"number of parts in which we divide the model"
"list to run evaluation in parallel"
),
)
parser.add_argument(
p1 + "seg-part-idx", default=1, type=int, help=("test part index")
)
parser.add_argument(
p1 + "num-seg-parts",
default=1,
type=int,
help=(
"number of parts in which we divide the test list "
"to run evaluation in parallel"
),
)
parser.add_argument(
p1 + "eval-set",
type=str.lower,
default="enroll-test",
choices=["enroll-test", "enroll-coh", "coh-test", "coh-coh"],
help=("evaluation subset"),
)
add_argparse_args = add_class_args