"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import numpy as np
import torch.distributed as dist
class Logger(object):
    """Base class for logger objects.

    Subclasses override the ``on_*`` hooks. The base implementations only
    record the current epoch/batch index; every other hook is a no-op.

    Attributes:
      cur_epoch: index of the current epoch, set by ``on_epoch_begin``.
      cur_batch: index of the current batch within the epoch, set by
        ``on_batch_begin``.
      params: training params dictionary; ``None`` until assigned by the
        caller (presumably the trainer -- not set anywhere in this class).
      rank: rank of this process in the default distributed group, or 0
        when torch.distributed is unavailable or not initialized.
      world_size: number of processes in the default distributed group, or
        1 when torch.distributed is unavailable or not initialized.
    """

    def __init__(self):
        # Query the process group only when one exists; otherwise behave as
        # a single-process run. Checking explicitly (rather than a bare
        # ``except:``) avoids swallowing KeyboardInterrupt/SystemExit and
        # unrelated errors. The ``and`` short-circuit matters because
        # ``dist.is_initialized`` may not be defined when
        # ``dist.is_available()`` is False.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1

        self.cur_epoch = 0
        self.cur_batch = 0
        self.params = None
        self.rank = rank
        self.world_size = world_size

    def on_epoch_begin(self, epoch, logs, **kwargs):
        """At the start of an epoch.

        Args:
          epoch: index of the epoch; stored in ``self.cur_epoch``.
          logs: dictionary of logs (unused by the base class).
          **kwargs: extra arguments, ignored by the base class.
        """
        self.cur_epoch = epoch

    def on_epoch_end(self, logs, **kwargs):
        """At the end of an epoch. No-op in the base class.

        Args:
          logs: dictionary of logs.
          **kwargs: extra arguments, ignored by the base class.
        """
        pass

    def on_batch_begin(self, batch, logs, **kwargs):
        """At the start of a batch.

        Args:
          batch: batch index within the epoch; stored in ``self.cur_batch``.
          logs: dictionary of logs (unused by the base class).
          **kwargs: extra arguments, ignored by the base class.
        """
        self.cur_batch = batch

    def on_batch_end(self, logs, **kwargs):
        """At the end of a batch. No-op in the base class.

        Args:
          logs: dictionary of logs.
          **kwargs: extra arguments, ignored by the base class.
        """
        pass

    def on_train_begin(self, logs, **kwargs):
        """At the start of training. No-op in the base class.

        Args:
          logs: dictionary of logs.
          **kwargs: extra arguments, ignored by the base class.
        """
        pass

    def on_train_end(self, logs, **kwargs):
        """At the end of training. No-op in the base class.

        Args:
          logs: dictionary of logs.
          **kwargs: extra arguments, ignored by the base class.
        """
        pass