# Source code for hyperion.torch.narchs.torch_na_loader

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import torch

from .fcnet import FCNetV1

from .tdnn import TDNNV1
from .etdnn import ETDNNV1
from .resetdnn import ResETDNNV1

from .resnet import *

from .transformer_encoder_v1 import TransformerEncoderV1
from .conformer_encoder_v1 import ConformerEncoderV1

from .dc1d_encoder import DC1dEncoder
from .dc1d_decoder import DC1dDecoder
from .dc2d_encoder import DC2dEncoder
from .dc2d_decoder import DC2dDecoder

from .resnet1d_encoder import ResNet1dEncoder
from .resnet1d_decoder import ResNet1dDecoder
from .resnet2d_encoder import ResNet2dEncoder
from .resnet2d_decoder import ResNet2dDecoder

from .efficient_net import EfficientNet

from .classif_head import ClassifHead

from .audio_feats_mvn import AudioFeatsMVN


class TorchNALoader(object):
    """Factory that rebuilds neural-architecture objects from saved configs.

    The architecture class is selected by the ``class_name`` entry of the
    model config. It is looked up first among the names in this module's
    global scope (e.g. the architectures imported at the top of this file)
    and then in the optional caller-supplied ``extra_objs`` mapping. The
    resolved class's ``load(cfg=..., state_dict=...)`` builds the object.
    """

    @staticmethod
    def _split_cfg(cfg):
        """Return ``(class_name, cfg_without_class_name)`` without mutating ``cfg``."""
        # Work on a shallow copy so the caller's config stays reusable.
        cfg = dict(cfg)
        class_name = cfg.pop("class_name")
        return class_name, cfg

    @staticmethod
    def _get_class_obj(class_name, extra_objs):
        """Resolve ``class_name`` to a class object.

        Args:
          class_name: name stored in the model config.
          extra_objs: optional mapping name -> class for architectures that
            are not available in this module's global scope.

        Returns:
          The class object.

        Raises:
          ValueError: if ``class_name`` is neither a module global nor a key
            of ``extra_objs``.
        """
        module_objs = globals()
        # Module globals take precedence over extra_objs (original lookup order).
        if class_name in module_objs:
            return module_objs[class_name]
        if extra_objs is not None and class_name in extra_objs:
            return extra_objs[class_name]
        raise ValueError(
            "unknown neural architecture object with class_name=%s" % (class_name)
        )

    @staticmethod
    def load(file_path, extra_objs=None):
        """Load a neural architecture from a checkpoint file.

        Args:
          file_path: path to a file written with ``torch.save`` containing a
            dict with keys ``"model_cfg"`` (which must include
            ``"class_name"``) and ``"model_state_dict"``.
          extra_objs: optional mapping name -> class used when the class is
            not found in this module's global scope.

        Returns:
          Whatever ``class_obj.load(cfg=..., state_dict=...)`` returns.

        Raises:
          ValueError: if the stored class name cannot be resolved.
          KeyError: if the checkpoint lacks one of the expected keys.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        model_data = torch.load(file_path)
        class_name, cfg = TorchNALoader._split_cfg(model_data["model_cfg"])
        class_obj = TorchNALoader._get_class_obj(class_name, extra_objs)
        state_dict = model_data["model_state_dict"]
        return class_obj.load(cfg=cfg, state_dict=state_dict)

    @staticmethod
    def load_from_cfg(cfg, state_dict=None, extra_objs=None):
        """Build a neural architecture from an in-memory config.

        Args:
          cfg: config dict that must contain ``"class_name"``. It is not
            modified; the class receives a copy without ``"class_name"``.
          state_dict: optional state dict forwarded to the class's ``load``.
          extra_objs: optional mapping name -> class used when the class is
            not found in this module's global scope.

        Returns:
          Whatever ``class_obj.load(cfg=..., state_dict=...)`` returns.

        Raises:
          ValueError: if ``cfg["class_name"]`` cannot be resolved.
          KeyError: if ``cfg`` has no ``"class_name"`` entry.
        """
        class_name, cfg = TorchNALoader._split_cfg(cfg)
        class_obj = TorchNALoader._get_class_obj(class_name, extra_objs)
        return class_obj.load(cfg=cfg, state_dict=state_dict)