Spaces:
Build error
Build error
"""Custom anomaly evaluation metrics.""" | |
import importlib | |
import warnings | |
from typing import List, Optional, Tuple, Union | |
import torchmetrics | |
from omegaconf import DictConfig, ListConfig | |
from .adaptive_threshold import AdaptiveThreshold | |
from .anomaly_score_distribution import AnomalyScoreDistribution | |
from .auroc import AUROC | |
from .collection import AnomalibMetricCollection | |
from .min_max import MinMax | |
from .optimal_f1 import OptimalF1 | |
__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"] | |
def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]: | |
"""Create metric collections based on the config. | |
Args: | |
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf | |
Returns: | |
AnomalibMetricCollection: Image-level metric collection | |
AnomalibMetricCollection: Pixel-level metric collection | |
""" | |
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else [] | |
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else [] | |
image_metrics = metric_collection_from_names(image_metric_names, "image_") | |
pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_") | |
return image_metrics, pixel_metrics | |
def metric_collection_from_names(metric_names: List[str], prefix: Optional[str]) -> AnomalibMetricCollection: | |
"""Create a metric collection from a list of metric names. | |
The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module, | |
then in TorchMetrics package. | |
Args: | |
metric_names (List[str]): List of metric names to be included in the collection. | |
prefix (Optional[str]): prefix to assign to the metrics in the collection. | |
Returns: | |
AnomalibMetricCollection: Collection of metrics. | |
""" | |
metrics_module = importlib.import_module("anomalib.utils.metrics") | |
metrics = AnomalibMetricCollection([], prefix=prefix, compute_groups=False) | |
for metric_name in metric_names: | |
if hasattr(metrics_module, metric_name): | |
metric_cls = getattr(metrics_module, metric_name) | |
metrics.add_metrics(metric_cls(compute_on_step=False)) | |
elif hasattr(torchmetrics, metric_name): | |
try: | |
metric_cls = getattr(torchmetrics, metric_name) | |
metrics.add_metrics(metric_cls(compute_on_step=False)) | |
except TypeError: | |
warnings.warn(f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package.") | |
else: | |
warnings.warn(f"No metric with name {metric_name} found in Anomalib metrics or TorchMetrics.") | |
return metrics | |