# Source code for hyperion.torch.metrics.accuracy

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import torch

from .metrics import TorchMetric
from .accuracy_functional import *


class CategoricalAccuracy(TorchMetric):
    """Metric module wrapping ``categorical_accuracy``.

    Args:
      weight: forwarded to ``TorchMetric``; later passed to
        ``categorical_accuracy`` as ``weight`` (presumably stored by the base
        class as ``self.weight`` — confirm in ``TorchMetric``).
      reduction: forwarded to ``TorchMetric``; later passed to
        ``categorical_accuracy`` as ``reduction``. Defaults to ``"mean"``.
    """

    def __init__(self, weight=None, reduction="mean"):
        super().__init__(weight=weight, reduction=reduction)

    def forward(self, input, target):
        """Return ``categorical_accuracy(input, target, ...)`` using the
        stored ``weight`` and ``reduction`` settings."""
        return categorical_accuracy(
            input,
            target,
            weight=self.weight,
            reduction=self.reduction,
        )
class BinaryAccuracy(TorchMetric):
    """Metric module wrapping ``binary_accuracy``.

    Args:
      weight: forwarded to ``TorchMetric``; later passed to
        ``binary_accuracy`` as ``weight``.
      reduction: forwarded to ``TorchMetric``; later passed to
        ``binary_accuracy`` as ``reduction``. Defaults to ``"mean"``.
      thr: decision threshold passed to ``binary_accuracy`` as ``thr``.
        Defaults to ``0.5`` (presumably inputs are probabilities — confirm
        against ``binary_accuracy``).
    """

    def __init__(self, weight=None, reduction="mean", thr=0.5):
        super().__init__(weight=weight, reduction=reduction)
        # Assigned after the base constructor has run.
        self.thr = thr

    def forward(self, input, target):
        """Return ``binary_accuracy(input, target, ...)`` using the stored
        ``weight``, ``reduction`` and ``thr`` settings."""
        return binary_accuracy(
            input,
            target,
            weight=self.weight,
            reduction=self.reduction,
            thr=self.thr,
        )
class BinaryAccuracyWithLogits(TorchMetric):
    """Metric module wrapping ``binary_accuracy_with_logits``.

    Args:
      weight: forwarded to ``TorchMetric``; later passed to
        ``binary_accuracy_with_logits`` as ``weight``.
      reduction: forwarded to ``TorchMetric``; later passed to
        ``binary_accuracy_with_logits`` as ``reduction``. Defaults to
        ``"mean"``.
      thr: decision threshold passed to ``binary_accuracy_with_logits`` as
        ``thr``. Defaults to ``0.0`` (presumably on the logit scale, i.e.
        probability 0.5 under a sigmoid — confirm against the functional).
    """

    def __init__(self, weight=None, reduction="mean", thr=0.0):
        super().__init__(weight=weight, reduction=reduction)
        # Assigned after the base constructor has run.
        self.thr = thr

    def forward(self, input, target):
        """Return ``binary_accuracy_with_logits(input, target, ...)`` using
        the stored ``weight``, ``reduction`` and ``thr`` settings."""
        return binary_accuracy_with_logits(
            input,
            target,
            weight=self.weight,
            reduction=self.reduction,
            thr=self.thr,
        )