File size: 2,763 Bytes
c8c12e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Custom anomaly evaluation metrics."""
import importlib
import warnings
from typing import List, Optional, Tuple, Union

import torchmetrics
from omegaconf import DictConfig, ListConfig

from .adaptive_threshold import AdaptiveThreshold
from .anomaly_score_distribution import AnomalyScoreDistribution
from .auroc import AUROC
from .collection import AnomalibMetricCollection
from .min_max import MinMax
from .optimal_f1 import OptimalF1

# Names exported by a star-import of this module. Only the metric classes imported above are
# listed; AnomalibMetricCollection and the helper functions defined in this file are not part
# of the list and have to be imported by name.
__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
    """Build the image-level and pixel-level metric collections requested by the config.

    The metric names are read from ``config.metrics.image`` and ``config.metrics.pixel``. When one of
    these keys is absent from ``config.metrics``, an empty list of names is used for that level.

    Args:
        config (Union[DictConfig, ListConfig]): Config loaded using OmegaConf; must expose a ``metrics`` section.

    Returns:
        AnomalibMetricCollection: Image-level metric collection, created with the ``image_`` prefix.
        AnomalibMetricCollection: Pixel-level metric collection, created with the ``pixel_`` prefix.
    """
    # Resolve both name lists before building anything, image level first.
    if "image" in config.metrics.keys():
        image_names = config.metrics.image
    else:
        image_names = []

    if "pixel" in config.metrics.keys():
        pixel_names = config.metrics.pixel
    else:
        pixel_names = []

    # Tuple elements are evaluated left to right: image collection, then pixel collection.
    return (
        metric_collection_from_names(image_names, "image_"),
        metric_collection_from_names(pixel_names, "pixel_"),
    )


def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection:
    """Assemble a metric collection from metric class names.

    Each name is looked up as an attribute of the ``anomalib.utils.metrics`` module first and, if it is
    not found there, as an attribute of the ``torchmetrics`` package. Every metric is instantiated with
    ``compute_on_step=False``. Names that cannot be resolved in either place are skipped with a warning.
    A ``TypeError`` raised while creating or adding a TorchMetrics metric is also reported as a warning
    instead of being propagated.

    Args:
        metric_names (List[str]): Names of the metrics to include in the collection.
        prefix (Optional[str]): Prefix passed on to the collection.

    Returns:
        AnomalibMetricCollection: Collection holding the metrics that could be created.
    """
    anomalib_metrics = importlib.import_module("anomalib.utils.metrics")
    collection = AnomalibMetricCollection([], prefix=prefix, compute_groups=False)

    for metric_name in metric_names:
        # Anomalib's own implementation takes precedence over a TorchMetrics class of the same name.
        if hasattr(anomalib_metrics, metric_name):
            anomalib_metric = getattr(anomalib_metrics, metric_name)(compute_on_step=False)
            collection.add_metrics(anomalib_metric)
            continue

        if not hasattr(torchmetrics, metric_name):
            warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.")
            continue

        # Only the TorchMetrics path is guarded: a TypeError here is downgraded to a warning.
        try:
            torch_metric = getattr(torchmetrics, metric_name)(compute_on_step=False)
            collection.add_metrics(torch_metric)
        except TypeError:
            warnings.warn(f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package.")

    return collection