Source code for hyperion.torch.utils.metric_acc

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
from collections import OrderedDict as ODict
import numpy as np

import torch
import torch.distributed as dist


class MetricAcc(object):
    """Class to accumulate metrics during an epoch.

    Keeps a sample-weighted running average of every metric passed to
    ``update``. When ``torch.distributed`` is initialized, the per-batch
    metrics are first averaged across processes, and only the process with
    rank 0 accumulates and reports them.

    Attributes:
      keys: list with the metric names, fixed by the first call to
        ``update``; ``None`` before that.
      acc: numpy array with the running average of each metric, in the
        same order as ``keys``; ``None`` before the first ``update``.
      count: number of samples accumulated so far.
      device: device where the tensor used for the distributed reduction
        is created.
      rank: rank of the current process (0 when not distributed).
      world_size: number of processes (1 when not distributed).
    """

    def __init__(self, device=None):
        """Creates an empty accumulator.

        Args:
          device: torch device used to build the tensor that is reduced
            across processes; it is only used when world_size > 1.
        """
        self.keys = None
        self.acc = None
        self.count = 0
        self.device = device
        # Ask explicitly whether a process group exists instead of calling
        # get_rank() and swallowing whatever exception it raises.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1

    def reset(self):
        """Resets the accumulators.

        The metric names are kept; only the sample count and the running
        averages are set back to zero.
        """
        self.count = 0
        if self.acc is not None:
            self.acc[:] = 0

    def _reduce(self, metrics):
        """Averages the batch metrics across processes, in place.

        Does nothing when there is a single process. Otherwise the values
        are packed into one tensor, summed onto rank 0 and divided by the
        number of processes; the entries of ``metrics`` are then replaced
        by the elements of that tensor.

        Args:
          metrics: dictionary with metrics for current batch; modified in
            place.
        """
        if self.world_size == 1:
            return

        metrics_tensor = torch.tensor(list(metrics.values()), device=self.device)
        # reduce() delivers the sum to the destination rank (0) only, which
        # is why update() discards the values on every other rank.
        dist.reduce(metrics_tensor, 0, op=dist.ReduceOp.SUM)
        metrics_tensor /= self.world_size
        for i, k in enumerate(metrics.keys()):
            metrics[k] = metrics_tensor[i]

    def update(self, metrics, num_samples=1):
        """Updates the values of the metric

        It uses recursive formula, it may be more numerically stable

            m^(i) = m^(i-1) + n^(i)/sum(n^(i)) (x^(i) - m^(i-1))

        where i is the batch number,
        m^(i) is the accumulated average of the metric at batch i,
        x^(i) is the average of the metric at batch i,
        n^(i) is the batch_size at batch i.

        Args:
          metrics: dictionary with metrics for current batch
          num_samples: number of samples in current batch (batch_size)
        """
        self._reduce(metrics)
        if self.rank != 0:
            return

        if self.keys is None:
            # Snapshot the names: a bare keys() view would stay tied to the
            # caller's dictionary and change if that dictionary changes.
            self.keys = list(metrics.keys())
            self.acc = np.zeros((len(self.keys),))

        self.count += num_samples
        if self.count == 0:
            # Nothing accumulated yet; the weight below would be 0/0.
            return

        r = num_samples / self.count
        for i, k in enumerate(self.keys):
            # float() accepts Python numbers, numpy scalars and 0-dim
            # tensors alike, so the numpy accumulator never sees a tensor.
            self.acc[i] += r * (float(metrics[k]) - self.acc[i])

    @property
    def metrics(self):
        """Returns metrics dictionary

        Returns:
          Ordered dictionary mapping each metric name to its accumulated
          average. It is empty on ranks other than 0 and when ``update``
          has not been called yet.
        """
        if self.rank != 0:
            return {}

        if self.keys is None:
            return ODict()

        return ODict(zip(self.keys, self.acc))