Source code for hyperion.torch.loggers.tensorboard_logger

"""
 Copyright 2021 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
import re
from torch.utils.tensorboard import SummaryWriter

from .logger import Logger


[docs]class TensorBoardLogger(Logger): """Logger that sends training progress to tensorboard Attributes: tb_path: tensorboard output directory """
[docs] def __init__(self, tb_path, interval=10): super().__init__() self.tb_path = tb_path self.writer = None self.interval = interval self.batches = 0 self.cur_epoch = 0 self.cur_batch = 0
[docs] def on_train_begin(self, logs=None, **kwargs): if self.rank != 0: return self.writer = SummaryWriter(self.tb_path)
[docs] def on_epoch_begin(self, epoch, logs=None, **kwargs): if self.rank != 0: return self.cur_epoch = epoch if "batches" in kwargs: self.batches = kwargs["batches"] else: self.batches = 0 self.cur_batch = 0
[docs] def on_batch_end(self, logs=None, **kwargs): if self.rank != 0: return self.cur_batch += 1 if (self.cur_batch % self.interval) == 0: step = self.cur_epoch * self.batches + self.cur_batch for k, v in logs.items(): self.writer.add_scalar(k + "/global_steps", v, step)
[docs] def on_epoch_end(self, logs=None, **kwargs): if self.rank != 0: return for k, v in logs.items(): k = re.sub(r"^(train|val)_(.*)$", r"\2/\1", k) self.writer.add_scalar(k, v, self.cur_epoch + 1)
[docs] def on_train_end(self, logs=None, **kwargs): if self.rank != 0: return self.writer.close() self.writer = None