Source code for hyperion.io.data_rw_factory

"""
 Copyright 2018 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import logging
from jsonargparse import ArgumentParser, ActionParser

from ..utils.kaldi_matrix import compression_methods
from .rw_specifiers import ArchiveType, WSpecifier, RSpecifier, WSpecType, RSpecType
from .h5_data_writer import H5DataWriter as H5DW
from .ark_data_writer import ArkDataWriter as ADW
from .ark_data_reader import SequentialArkFileDataReader as SAFDR
from .ark_data_reader import SequentialArkScriptDataReader as SASDR
from .ark_data_reader import RandomAccessArkDataReader as RADR
from .h5_data_writer import H5DataWriter as H5DW
from .h5_data_reader import SequentialH5FileDataReader as SH5FDR
from .h5_data_reader import SequentialH5ScriptDataReader as SH5SDR
from .h5_data_reader import RandomAccessH5FileDataReader as RH5FDR
from .h5_data_reader import RandomAccessH5ScriptDataReader as RH5SDR


[docs]class DataWriterFactory(object): """ Class to create object that write data to hdf5/ark files. """
[docs] @staticmethod def create(wspecifier, compress=False, compression_method="auto", scp_sep=" "): if isinstance(wspecifier, str): wspecifier = WSpecifier.create(wspecifier) if ( wspecifier.spec_type == WSpecType.ARCHIVE or wspecifier.spec_type == WSpecType.BOTH ): if wspecifier.archive_type == ArchiveType.H5: return H5DW( wspecifier.archive, wspecifier.script, flush=wspecifier.flush, compress=compress, compression_method=compression_method, scp_sep=scp_sep, ) else: return ADW( wspecifier.archive, wspecifier.script, binary=wspecifier.binary, flush=wspecifier.flush, compress=compress, compression_method=compression_method, scp_sep=scp_sep, )
[docs] @staticmethod def filter_args(**kwargs): valid_args = ("scp_sep", "compress", "compression_method") return dict((k, kwargs[k]) for k in valid_args if k in kwargs)
[docs] @staticmethod def add_class_args(parser, prefix=None): if prefix is not None: outer_parser = parser parser = ArgumentParser(prog="") parser.add_argument("--scp-sep", default=" ", help=("scp file field separator")) parser.add_argument("--compress", default=False, action="store_true") parser.add_argument( "--compression-method", default="auto", choices=compression_methods ) if prefix is not None: outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
# help='data writer options')
[docs]class SequentialDataReaderFactory(object):
[docs] @staticmethod def create(rspecifier, path_prefix=None, scp_sep=" ", **kwargs): if isinstance(rspecifier, str): rspecifier = RSpecifier.create(rspecifier) if rspecifier.spec_type == RSpecType.ARCHIVE: if rspecifier.archive_type == ArchiveType.H5: return SH5FDR(rspecifier.archive, **kwargs) else: return SAFDR(rspecifier.archive, **kwargs) else: if rspecifier.archive_type == ArchiveType.H5: return SH5SDR(rspecifier.script, path_prefix, scp_sep=scp_sep, **kwargs) else: return SASDR(rspecifier.script, path_prefix, scp_sep=scp_sep, **kwargs)
[docs] @staticmethod def filter_args(**kwargs): valid_args = ("scp_sep", "path_prefix", "part_idx", "num_parts") return dict((k, kwargs[k]) for k in valid_args if k in kwargs)
[docs] @staticmethod def add_class_args(parser, prefix=None): if prefix is not None: outer_parser = parser parser = ArgumentParser(prog="") try: parser.add_argument( "--scp-sep", default=" ", help=("scp file field separator") ) except: pass parser.add_argument( "--path-prefix", default=None, help=("scp file_path prefix") ) try: parser.add_argument( "--part-idx", type=int, default=1, help=("splits the list of files in num-parts " "and process part_idx"), ) parser.add_argument( "--num-parts", type=int, default=1, help=("splits the list of files in num-parts " "and process part_idx"), ) except: pass if prefix is not None: outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
# help='data reader options')
[docs]class RandomAccessDataReaderFactory(object):
[docs] @staticmethod def create(rspecifier, path_prefix=None, transform=None, scp_sep=" "): if isinstance(rspecifier, str): rspecifier = RSpecifier.create(rspecifier) logging.debug(rspecifier.__dict__) if rspecifier.spec_type == RSpecType.ARCHIVE: if rspecifier.archive_type == ArchiveType.H5: return RH5FDR( rspecifier.archive, transform=transform, permissive=rspecifier.permissive, ) else: raise ValueError( "Random access to Ark file %s needs a script file" % rspecifier.archive ) else: if rspecifier.archive_type == ArchiveType.H5: return RH5SDR( rspecifier.archive, path_prefix, transform=transform, permissive=rspecifier.permissive, scp_sep=scp_sep, ) else: return RADR( rspecifier.script, path_prefix, transform=transform, permissive=rspecifier.permissive, scp_sep=scp_sep, )
[docs] @staticmethod def filter_args(**kwargs): valid_args = ("scp_sep", "path_prefix") return dict((k, kwargs[k]) for k in valid_args if k in kwargs)
[docs] @staticmethod def add_class_args(parser, prefix=None): if prefix is not None: outer_parser = parser parser = ArgumentParser(prog="") try: parser.add_argument( "--scp-sep", default=" ", help=("scp file field separator") ) except: pass parser.add_argument( "--path-prefix", default=None, help=("scp file_path prefix") ) if prefix is not None: outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
# help='data reader options') add_argparse_args = add_class_args