Source code for hyperion.torch.loggers.prog_logger

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import time
import logging
from collections import OrderedDict

import numpy as np

from .logger import Logger


class ProgLogger(Logger):
    """Logger that reports training progress through ``logging.info``.

    Only the process whose ``rank`` is 0 emits messages.
    NOTE(review): ``rank`` and ``world_size`` are not set in this class;
    presumably they are provided by the ``Logger`` base class -- confirm.

    Attributes:
      metrics: set of metric names to print, or None to print every entry
        found in ``logs``.
      interval: number of batches between progress messages.
      epochs: total number of epochs, read in ``on_train_begin``.
      batches: number of batches in the current epoch (0 when unknown).
      samples: number of samples in the current epoch (0 when unknown).
      cur_epoch: index of the current epoch (printed 1-based).
      cur_batch: number of batches processed in the current epoch.
      cur_sample: number of samples processed in the current epoch.
      t0: ``time.time()`` value recorded when the current epoch started.
    """

    def __init__(self, metrics=None, interval=10):
        """Create the logger.

        Args:
          metrics: iterable of metric names to print; None prints all.
          interval: number of batches between progress messages.
        """
        super().__init__()
        self.metrics = None if metrics is None else set(metrics)
        self.interval = interval
        self.epochs = 0
        self.batches = 0
        self.samples = 0
        self.cur_epoch = 0
        self.cur_batch = 0
        self.cur_sample = 0
        self.t0 = 0

    def on_train_begin(self, logs=None, **kwargs):
        """Store the total number of epochs.

        Args:
          logs: unused.
          kwargs: must contain ``epochs``.

        Raises:
          KeyError: if ``epochs`` is not given.
        """
        self.epochs = kwargs["epochs"]

    def on_epoch_begin(self, epoch, logs=None, **kwargs):
        """Reset the per-epoch counters and start the epoch timer.

        Args:
          epoch: index of the epoch that starts (0-based).
          logs: unused.
          kwargs: may contain ``samples`` (multiplied by ``world_size``)
            and ``batches``; a missing value is stored as 0.
        """
        if self.rank != 0:
            return

        self.cur_epoch = epoch
        logging.info("epoch: %d/%d starts" % (epoch + 1, self.epochs))

        if "samples" in kwargs:
            self.samples = kwargs["samples"] * self.world_size
        else:
            self.samples = 0

        self.batches = kwargs.get("batches", 0)
        self.cur_batch = 0
        self.cur_sample = 0
        self.t0 = time.time()

    def on_batch_begin(self, batch, logs=None, **kwargs):
        """Record the index of the batch that is about to be processed.

        Args:
          batch: batch index.
          logs: unused.
          kwargs: unused.
        """
        self.cur_batch = batch

    def on_batch_end(self, logs=None, **kwargs):
        """Update the counters and print progress every ``interval`` batches.

        Args:
          logs: mapping from metric name to numeric value, or None.
          kwargs: may contain ``batch_size`` (multiplied by ``world_size``).
        """
        if self.rank != 0:
            return

        if "batch_size" in kwargs:
            self.cur_sample += kwargs["batch_size"] * self.world_size
        self.cur_batch += 1

        if (self.cur_batch % self.interval) != 0:
            return

        info = "epoch: %d/%d " % (self.cur_epoch + 1, self.epochs)

        etime, eta = self.estimate_epoch_time()
        if eta is None:
            info += " et: %s" % (etime)
        else:
            info += " et: %s eta: %s" % (etime, eta)

        if self.batches > 0:
            info += " batches: %d/%d(%d%%)" % (
                self.cur_batch,
                self.batches,
                int(100 * self.cur_batch / self.batches),
            )
        else:
            info += " batches: %d" % (self.cur_batch)

        if self.cur_sample > 0:
            if self.samples > 0:
                info += " samples: %d/%d(%d%%)" % (
                    self.cur_sample,
                    self.samples,
                    int(100 * self.cur_sample / self.samples),
                )
            else:
                info += " samples: %d" % (self.cur_sample)

        info += self._format_metrics(logs)
        logging.info(info)

    def on_epoch_end(self, logs=None, **kwargs):
        """Print the metrics collected at the end of the epoch.

        Args:
          logs: mapping from metric name to numeric value, or None.
          kwargs: unused.
        """
        if self.rank != 0:
            return

        info = "epoch: %d/%d " % (self.cur_epoch + 1, self.epochs)
        info += self._format_metrics(logs)
        # The message used to be assembled and then dropped; emit it.
        logging.info(info)

    def estimate_epoch_time(self):
        """Estimate the elapsed and remaining time of the current epoch.

        The total time is extrapolated from the fraction of batches done
        when the number of batches is known, otherwise from the fraction of
        samples done.

        Returns:
          Tuple ``(etime, eta)`` of strings formatted by ``sec2str``:
          elapsed time and estimated remaining time. ``eta`` is None when
          there is not enough information to extrapolate.
        """
        elapsed = time.time() - self.t0

        # Require a non-zero progress counter so we never divide by zero.
        if self.batches > 0 and self.cur_batch > 0:
            total = elapsed / self.cur_batch * self.batches
        elif self.samples > 0 and self.cur_sample > 0:
            total = elapsed / self.cur_sample * self.samples
        else:
            total = None

        etime = self.sec2str(elapsed)
        eta = None if total is None else self.sec2str(total - elapsed)
        return etime, eta

    @staticmethod
    def sec2str(t):
        """Format a duration in seconds as a compact string.

        Args:
          t: duration in seconds (int or float); fractions are truncated
            and negative values are treated as 0.

        Returns:
          ``"D:HH:MM:SS"``, ``"H:MM:SS"``, ``"M:SS"`` or ``"Ss"``, using the
          shortest form that fits the duration.
        """
        # Plain arithmetic instead of time.gmtime(): the calendar day of
        # month wraps at month boundaries and breaks on negative input.
        total = max(int(t), 0)
        minutes, secs = divmod(total, 60)
        hours, minutes = divmod(minutes, 60)
        days, hours = divmod(hours, 24)

        if days > 0:
            return "%d:%02d:%02d:%02d" % (days, hours, minutes, secs)
        if hours > 0:
            return "%d:%02d:%02d" % (hours, minutes, secs)
        if minutes > 0:
            return "%d:%02d" % (minutes, secs)
        return "%ds" % (secs)

    def _format_metrics(self, logs):
        """Render the selected entries of ``logs`` as `` name: value`` text."""
        if not logs:
            return ""
        return "".join(
            " %s: %.6f" % (name, value)
            for name, value in logs.items()
            if self.metrics is None or name in self.metrics
        )