Source code for hyperion.torch.metrics.accuracy_functional

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""


import torch


[docs]def categorical_accuracy(input, target, weight=None, reduction="mean"): dim = input.dim() if dim < 2: raise ValueError("Expected 2 or more dimensions (got %d)" % (dim)) if input.size(0) != target.size(0): raise ValueError( "Expected input batch_size (%d) to match target batch_size (%d)." % (input.size(0), target.size(0)) ) with torch.no_grad(): _, pred = torch.max(input, dim=-1) if target.dim() == 2: _, target = torch.max(target, dim=-1) ok = pred.eq(target).float() if reduction == "none": return ok weight_mean = 1 if weight is not None: if input.size(0) != weight.size(0): raise ValueError( "Expected input batch_size (%d) to match weight batch_size (%d)." % (input.size(0), weight.size(0)) ) ok *= weight weight_mean = weight.mean() if reduction == "sum": return ok.sum().item() acc = ok.mean() / weight_mean return acc.item()
[docs]def binary_accuracy(input, target, weight=None, reduction="mean", thr=0.5): dim = input.dim() if dim < 2: raise ValueError("Expected 2 or more dimensions (got %d)" % (dim)) if not (target.size() == input.size()): raise ValueError( "Target size ({}) is different to the input size ({}).".format( target.size(), input.size() ) ) if input.numel() != target.numel(): raise ValueError( "Target and input must have the same number of elements. target nelement ({}) " "!= input nelement ({})".format(target.numel(), input.numel()) ) with torch.no_grad(): pred = input > thr ok = pred.eq(target).float() if reduction == "none": return ok weight_mean = 1 if weight is not None: if input.size(0) != weight.size(0): raise ValueError( "Expected input batch_size (%d) to match weight batch_size (%d)." % (input.size(0), weight.size(0)) ) if weight.dim() == 1: ok *= weight.unsqueeze(1) else: ok *= weight weight_mean = weight.mean() if reduction == "sum": return ok.sum().item() acc = ok.mean() / weight_mean return ok.item()
[docs]def binary_accuracy_with_logits(input, target, weight=None, reduction="mean", thr=0): return binary_accuracy(input, target, weight, reduction, thr)