julien.blanchon
add app
c8c12e9
raw
history blame
2.76 kB
"""Custom anomaly evaluation metrics."""
import importlib
import warnings
from typing import List, Optional, Tuple, Union
import torchmetrics
from omegaconf import DictConfig, ListConfig
from .adaptive_threshold import AdaptiveThreshold
from .anomaly_score_distribution import AnomalyScoreDistribution
from .auroc import AUROC
from .collection import AnomalibMetricCollection
from .min_max import MinMax
from .optimal_f1 import OptimalF1
# Names exported by ``from <this package> import *``.
__all__ = [
    "AUROC",
    "OptimalF1",
    "AdaptiveThreshold",
    "AnomalyScoreDistribution",
    "MinMax",
]
def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
    """Create image-level and pixel-level metric collections based on the config.

    Metric names are read from ``config.metrics.image`` and ``config.metrics.pixel``. A level whose
    key is absent, or present with a null value, yields an empty collection for that level.

    Args:
        config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf.

    Returns:
        AnomalibMetricCollection: Image-level metric collection (metrics prefixed with ``image_``).
        AnomalibMetricCollection: Pixel-level metric collection (metrics prefixed with ``pixel_``).
    """
    metrics_config = config.metrics
    image_metric_names = metrics_config.image if "image" in metrics_config.keys() else []
    pixel_metric_names = metrics_config.pixel if "pixel" in metrics_config.keys() else []
    # A level key present with a null value (e.g. ``image:`` with every entry commented out in the
    # YAML) would otherwise be passed on as None and fail when iterated.
    if image_metric_names is None:
        image_metric_names = []
    if pixel_metric_names is None:
        pixel_metric_names = []
    image_metrics = metric_collection_from_names(image_metric_names, "image_")
    pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
    return image_metrics, pixel_metrics
def metric_collection_from_names(
    metric_names: Optional[Union[str, List[str]]], prefix: Optional[str]
) -> AnomalibMetricCollection:
    """Create a metric collection from a list of metric names.

    Each name is looked up first in the Anomalib metrics module (``anomalib.utils.metrics``) and then
    in the TorchMetrics package. The matching class is instantiated with ``compute_on_step=False`` and
    added to the collection. A name that is found in neither place, or a TorchMetrics class whose
    constructor rejects those arguments, is skipped with a warning instead of raising.

    Args:
        metric_names (Optional[Union[str, List[str]]]): Names of the metrics to include in the
            collection. A single name given as a plain string is treated as a one-element list, and
            ``None`` is treated as an empty list.
        prefix (Optional[str]): Prefix to assign to the metrics in the collection.

    Returns:
        AnomalibMetricCollection: Collection of metrics.
    """
    # Without this, a bare string would be iterated character by character, and None would raise
    # TypeError in the loop below.
    if metric_names is None:
        metric_names = []
    elif isinstance(metric_names, str):
        metric_names = [metric_names]

    metrics_module = importlib.import_module("anomalib.utils.metrics")
    metrics = AnomalibMetricCollection([], prefix=prefix, compute_groups=False)
    for metric_name in metric_names:
        # NOTE(review): hasattr matches any module attribute, not only metric classes; a non-metric
        # name (e.g. a helper function) would fail at construction — confirm configs never do this.
        if hasattr(metrics_module, metric_name):
            metric_cls = getattr(metrics_module, metric_name)
            metrics.add_metrics(metric_cls(compute_on_step=False))
        elif hasattr(torchmetrics, metric_name):
            metric_cls = getattr(torchmetrics, metric_name)
            # Guard only the constructor call: a TypeError raised by add_metrics is a different
            # failure and must not be reported as bad constructor arguments.
            try:
                metric = metric_cls(compute_on_step=False)
            except TypeError:
                warnings.warn(f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package.")
            else:
                metrics.add_metrics(metric)
        else:
            warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.")
    return metrics