Spaces:
Build error
Build error
"""Implementation of Optimal F1 score based on TorchMetrics.""" | |
import torch | |
from torchmetrics import Metric, PrecisionRecallCurve | |
class OptimalF1(Metric):
    """Optimal F1 Metric.

    Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the
    predicted anomaly scores.

    All predictions and targets are accumulated by an internal ``PrecisionRecallCurve`` metric; this class
    registers no states of its own. After ``compute`` has been called, the threshold that yielded the highest
    F1 score is available as ``self.threshold``.

    Args:
        num_classes: Forwarded unchanged to ``PrecisionRecallCurve``.
        **kwargs: Forwarded unchanged to the ``Metric`` base-class constructor.
    """

    def __init__(self, num_classes: int, **kwargs):
        super().__init__(**kwargs)

        # The wrapped curve metric owns the accumulated predictions/targets.
        self.precision_recall_curve = PrecisionRecallCurve(num_classes=num_classes, compute_on_step=False)

        # Annotation only: the attribute is first assigned inside compute().
        self.threshold: torch.Tensor

    # pylint: disable=arguments-differ
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:  # type: ignore
        """Update the precision-recall curve metric.

        Args:
            preds: Predicted scores, passed straight through to the wrapped curve metric.
            target: Ground-truth labels, passed straight through to the wrapped curve metric.
        """
        self.precision_recall_curve.update(preds, target)

    def compute(self) -> torch.Tensor:
        """Compute the value of the optimal F1 score.

        Compute the F1 scores while varying the threshold. Store the optimal
        threshold as attribute and return the maximum value of the F1 score.

        Returns:
            Value of the F1 score at the optimal threshold.
        """
        precision: torch.Tensor
        recall: torch.Tensor
        thresholds: torch.Tensor

        # NOTE(review): the arithmetic below treats the three results as plain tensors, i.e. a single
        # curve -- confirm the behaviour for num_classes > 1.
        precision, recall, thresholds = self.precision_recall_curve.compute()

        # The small epsilon keeps the division defined when precision + recall == 0.
        f1_score = (2 * precision * recall) / (precision + recall + 1e-10)

        self.threshold = thresholds[torch.argmax(f1_score)]
        optimal_f1_score = torch.max(f1_score)
        return optimal_f1_score

    def reset(self) -> None:
        """Reset the metric, including the wrapped precision-recall curve.

        The base-class ``reset`` only restores states registered on this metric itself; the accumulated
        predictions and targets live in the child ``precision_recall_curve`` metric, which therefore has to
        be reset explicitly. Without this, data from earlier epochs would leak into later ``compute`` calls.
        """
        super().reset()
        self.precision_recall_curve.reset()