Spaces:
Build error
Build error
"""Implementation of Optimal F1 score based on TorchMetrics.""" | |
import torch | |
from torchmetrics import Metric, PrecisionRecallCurve | |
class AdaptiveThreshold(Metric):
    """Adaptive threshold that maximises the F1 score.

    Predicted anomaly scores and ground-truth labels are accumulated in an
    internal precision-recall curve. On ``compute``, the F1 score is evaluated
    for the candidate thresholds of that curve. The threshold with the highest
    F1 score is stored in ``self.value`` and returned.

    Args:
        default_value: Threshold held in ``self.value`` before ``compute`` is
            called; it is also the default the ``value`` state is registered with.
        **kwargs: Forwarded unchanged to ``torchmetrics.Metric``.
    """

    def __init__(self, default_value: float, **kwargs):
        super().__init__(**kwargs)

        # Child metric that accumulates every (preds, target) pair passed to update().
        self.precision_recall_curve = PrecisionRecallCurve(num_classes=1, compute_on_step=False)

        # Registered as a metric state with persistent=True so the threshold is part of
        # the module's saved state.
        self.add_state("value", default=torch.tensor(default_value), persistent=True)  # pylint: disable=not-callable
        self.value = torch.tensor(default_value)  # pylint: disable=not-callable

    # pylint: disable=arguments-differ
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:  # type: ignore
        """Add a batch of predicted scores and ground-truth labels to the precision-recall curve.

        Args:
            preds: Predicted anomaly scores.
            target: Ground-truth labels for ``preds``.
        """
        self.precision_recall_curve.update(preds, target)

    def reset(self) -> None:
        """Reset the stored threshold and the accumulated precision-recall data.

        ``Metric.reset`` only restores the states registered on this object via
        ``add_state``; it does not touch child metrics. The internal
        precision-recall curve is therefore reset explicitly. Otherwise
        predictions from earlier epochs would leak into the next computed threshold.
        """
        super().reset()
        self.precision_recall_curve.reset()

    def compute(self) -> torch.Tensor:
        """Compute the threshold that yields the optimal F1 score.

        The F1 score is evaluated from the precision and recall of the
        accumulated precision-recall curve. The threshold at the position of the
        highest F1 score is stored in ``self.value``.

        Returns:
            The threshold at which the F1 score is maximal (also stored in
            ``self.value``).
        """
        precision: torch.Tensor
        recall: torch.Tensor
        thresholds: torch.Tensor
        precision, recall, thresholds = self.precision_recall_curve.compute()
        # The small epsilon avoids a 0/0 division where precision and recall are both zero.
        f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
        if thresholds.dim() == 0:
            # special case where recall is 1.0 even for the highest threshold.
            # In this case 'thresholds' will be scalar.
            self.value = thresholds
        else:
            self.value = thresholds[torch.argmax(f1_score)]
        return self.value