"""Custom anomaly evaluation metrics."""
import importlib
import warnings
from typing import List, Optional, Tuple, Union
import torchmetrics
from omegaconf import DictConfig, ListConfig
from .adaptive_threshold import AdaptiveThreshold
from .anomaly_score_distribution import AnomalyScoreDistribution
from .auroc import AUROC
from .collection import AnomalibMetricCollection
from .min_max import MinMax
from .optimal_f1 import OptimalF1
__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]
def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
"""Create metric collections based on the config.
Args:
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf
Returns:
AnomalibMetricCollection: Image-level metric collection
AnomalibMetricCollection: Pixel-level metric collection
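
    Example:
        A minimal sketch of how this helper might be called, assuming the config follows the
        layout expected here (``metrics.image`` and ``metrics.pixel`` name lists); the metric
        name below is illustrative:

        >>> from omegaconf import OmegaConf
        >>> config = OmegaConf.create({"metrics": {"image": ["AUROC"], "pixel": ["AUROC"]}})
        >>> image_metrics, pixel_metrics = get_metrics(config)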
"""
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else []
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else []
image_metrics = metric_collection_from_names(image_metric_names, "image_")
pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
return image_metrics, pixel_metrics
def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection:
"""Create a metric collection from a list of metric names.
The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module,
then in TorchMetrics package.
Args:
metric_names (List[str]): List of metric names to be included in the collection.
prefix (Optional[str]): prefix to assign to the metrics in the collection.
Returns:
AnomalibMetricCollection: Collection of metrics.
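
    Example:
        A minimal sketch, assuming ``AUROC`` resolves to the metric defined in this module and
        ``F1Score`` to a metric shipped with the TorchMetrics package (both names are
        illustrative; any resolvable metric name can be used):

        >>> image_metrics = metric_collection_from_names(["AUROC", "F1Score"], prefix="image_")
        >>> unknown = metric_collection_from_names(["NotAMetric"], prefix=None)  # emits a warning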
"""
metrics_module = importlib.import_module("anomalib.utils.metrics")
metrics = AnomalibMetricCollection([], prefix=prefix, compute_groups=False)
for metric_name in metric_names:
if hasattr(metrics_module, metric_name):
metric_cls = getattr(metrics_module, metric_name)
metrics.add_metrics(metric_cls(compute_on_step=False))
elif hasattr(torchmetrics, metric_name):
try:
metric_cls = getattr(torchmetrics, metric_name)
metrics.add_metrics(metric_cls(compute_on_step=False))
except TypeError:
warnings.warn(f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package.")
else:
warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.")
return metrics