Spaces:
Build error
Build error
File size: 814 Bytes
c8c12e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
"""Implementation of AUROC metric based on TorchMetrics."""
import torch
from torch import Tensor
from torchmetrics import ROC
from torchmetrics.functional import auc
class AUROC(ROC):
    """Area under the receiver operating characteristic (ROC) curve.

    The parent ``ROC`` metric turns the accumulated state into an ROC curve
    (false positive rates, true positive rates, thresholds); this subclass then
    integrates that curve with ``torchmetrics.functional.auc``.
    """

    def compute(self) -> Tensor:
        """Build the ROC curve via the parent metric, then integrate it.

        Returns:
            Value of the AUROC metric
        """
        tpr: Tensor
        fpr: Tensor
        fpr, tpr, _thresholds = super().compute()

        # Successive differences of the false positive rates; the curve only needs
        # reordering when these change sign (i.e. fpr is neither non-increasing
        # nor non-decreasing).
        fpr_steps = fpr.diff()
        # TODO: use stable sort after upgrading to pytorch 1.9.x (https://github.com/openvinotoolkit/anomalib/issues/92)
        if torch.all(fpr_steps <= 0) or torch.all(fpr_steps >= 0):
            # Already monotonic: integrate the points in the order given.
            return auc(fpr, tpr)
        # Not monotonic: let ``auc`` sort the points before integrating.
        return auc(fpr, tpr, reorder=True)
|