"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import os
from copy import deepcopy

import torch
import torch.nn as nn
class TorchModel(nn.Module):
    """Base ``nn.Module`` with config-based save/load helpers.

    Subclasses are expected to extend :meth:`get_config` so that the
    returned dict (minus ``"class_name"``) can be passed as keyword
    arguments to the subclass constructor; :meth:`load` relies on this.
    """

    def get_config(self):
        """Return the model configuration as a dict.

        The base implementation only records the class name under the
        ``"class_name"`` key.
        """
        config = {"class_name": self.__class__.__name__}
        return config

    def copy(self):
        """Return a deep copy of the model."""
        return deepcopy(self)

    def save(self, file_path):
        """Write the model config and state dict to ``file_path``.

        The checkpoint is a dict with keys ``"model_cfg"`` and
        ``"model_state_dict"``, serialized with ``torch.save``.

        Args:
            file_path: Destination file. Missing parent directories are
                created.
        """
        file_dir = os.path.dirname(file_path)
        # dirname is "" for a bare file name, and os.makedirs("") raises,
        # so only create the directory when there is one to create.
        if file_dir:
            os.makedirs(file_dir, exist_ok=True)

        torch.save(
            {"model_cfg": self.get_config(), "model_state_dict": self.state_dict()},
            file_path,
        )

    def freeze(self):
        """Set ``requires_grad = False`` on every parameter."""
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Set ``requires_grad = True`` on every parameter."""
        for param in self.parameters():
            param.requires_grad = True

    @staticmethod
    def _load_cfg_state_dict(file_path=None, cfg=None, state_dict=None):
        """Resolve the constructor config and state dict for :meth:`load`.

        Values passed explicitly take precedence; anything missing is read
        from the checkpoint at ``file_path`` when one is given.

        Args:
            file_path: Checkpoint written by :meth:`save`. Required when
                ``cfg`` is None.
            cfg: Model configuration dict. It is not modified.
            state_dict: Model state dict.

        Returns:
            Tuple ``(cfg, state_dict)`` where ``cfg`` is a new dict without
            the ``"class_name"`` key and ``state_dict`` may be None when
            neither an explicit state dict nor a checkpoint was supplied.

        Raises:
            AssertionError: If both ``cfg`` and ``file_path`` are None.
        """
        model_data = None
        # Read the checkpoint whenever it can supply something the caller
        # did not pass explicitly (the config, the weights, or both).
        if cfg is None or (state_dict is None and file_path is not None):
            assert file_path is not None
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            model_data = torch.load(file_path)

        if cfg is None:
            cfg = model_data["model_cfg"]
        if state_dict is None and model_data is not None:
            state_dict = model_data["model_state_dict"]

        # Work on a copy so the caller's dict keeps its "class_name" entry.
        cfg = dict(cfg)
        cfg.pop("class_name", None)
        return cfg, state_dict

    @classmethod
    def load(cls, file_path=None, cfg=None, state_dict=None):
        """Build a model of type ``cls`` from a checkpoint and/or explicit values.

        Args:
            file_path: Checkpoint written by :meth:`save`.
            cfg: Configuration dict; overrides the one in the checkpoint.
            state_dict: State dict; overrides the one in the checkpoint.

        Returns:
            A new ``cls`` instance constructed as ``cls(**cfg)``, with the
            state dict loaded when one is available.
        """
        cfg, state_dict = TorchModel._load_cfg_state_dict(file_path, cfg, state_dict)
        model = cls(**cfg)
        if state_dict is not None:
            model.load_state_dict(state_dict)
        return model

    def get_reg_loss(self):
        """Return the regularization loss; 0 in this base class."""
        return 0

    def get_loss(self):
        """Return the model loss; 0 in this base class."""
        return 0

    @property
    def device(self):
        """Device shared by all parameters and buffers of the model.

        Raises:
            RuntimeError: If the parameters and buffers are not found on
                exactly one device (this includes a model that has none).
        """
        devices = {param.device for param in self.parameters()} | {
            buf.device for buf in self.buffers()
        }
        if len(devices) != 1:
            raise RuntimeError(
                "Cannot determine device: {} different devices found".format(
                    len(devices)
                )
            )
        return next(iter(devices))