julien.blanchon
add app
c8c12e9
raw
history blame contribute delete
814 Bytes
"""Implementation of AUROC metric based on TorchMetrics."""
import torch
from torch import Tensor
from torchmetrics import ROC
from torchmetrics.functional import auc
class AUROC(ROC):
    """Area under the Receiver Operating Characteristic (ROC) curve.

    Builds on TorchMetrics' ``ROC`` metric: the curve produced by
    ``ROC.compute`` is integrated with ``torchmetrics.functional.auc``.
    """

    def compute(self) -> Tensor:
        """Compute the ROC curve, then return the area underneath it.

        The false-positive rates from ``ROC.compute`` are inspected first.
        The points are only reordered before integration when ``fpr`` is
        neither non-increasing nor non-decreasing.

        Returns:
            Value of the AUROC metric
        """
        fpr: Tensor
        tpr: Tensor
        fpr, tpr, _thresholds = super().compute()

        # TODO: use stable sort after upgrading to pytorch 1.9.x (https://github.com/openvinotoolkit/anomalib/issues/92)
        fpr_steps: Tensor = fpr.diff()
        # Monotonic in either direction means the points are already ordered
        # along the x-axis and can be integrated as-is.
        is_monotonic = bool((fpr_steps <= 0).all() or (fpr_steps >= 0).all())

        # ``reorder=False`` is the default of ``auc``, so reordering only
        # happens for a non-monotonic ``fpr``.
        return auc(fpr, tpr, reorder=not is_monotonic)