"""Base Anomaly Module for Training Task.""" | |
# Copyright (C) 2020 Intel Corporation | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions | |
# and limitations under the License. | |
from abc import ABC | |
from typing import Any, List, Optional, Union | |
import pytorch_lightning as pl | |
from omegaconf import DictConfig, ListConfig | |
from pytorch_lightning.callbacks.base import Callback | |
from torch import Tensor, nn | |
from anomalib.utils.metrics import ( | |
AdaptiveThreshold, | |
AnomalyScoreDistribution, | |
MinMax, | |
get_metrics, | |
) | |

class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.

    Args:
        params (Union[DictConfig, ListConfig]): Configuration
    """

    def __init__(self, params: Union[DictConfig, ListConfig]):
        super().__init__()
        # Force the type for hparams so that it works with OmegaConf-style attribute access.
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(params)
        self.loss: Tensor
        self.callbacks: List[Callback]

        self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default).cpu()
        self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default).cpu()

        self.training_distribution = AnomalyScoreDistribution().cpu()
        self.min_max = MinMax().cpu()

        self.model: nn.Module

        # metrics
        self.image_metrics, self.pixel_metrics = get_metrics(self.hparams)
        self.image_metrics.set_threshold(self.hparams.model.threshold.image_default)
        self.pixel_metrics.set_threshold(self.hparams.model.threshold.pixel_default)
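
    # The config keys read directly by this class are ``model.threshold.image_default``,
    # ``model.threshold.pixel_default`` and, in ``validation_epoch_end``,
    # ``model.threshold.adaptive``. A minimal sketch with illustrative values
    # (``get_metrics`` consumes additional keys that are not visible in this file):
    #
    #   from omegaconf import OmegaConf
    #
    #   params = OmegaConf.create(
    #       {"model": {"threshold": {"image_default": 0.5, "pixel_default": 0.5, "adaptive": True}}}
    #   )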

    def forward(self, batch):  # pylint: disable=arguments-differ
        """Forward-pass input tensor to the module.

        Args:
            batch (Tensor): Input Tensor

        Returns:
            Tensor: Output tensor from the model.
        """
        return self.model(batch)

    def validation_step(self, batch, batch_idx) -> dict:  # type: ignore  # pylint: disable=arguments-differ
        """To be implemented in the subclasses."""
        raise NotImplementedError
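
    # A concrete model subclasses ``AnomalyModule`` and implements ``validation_step``
    # so that it returns the batch enriched with ``anomaly_maps`` and/or ``pred_scores``.
    # A minimal sketch, assuming a hypothetical ``MyBackbone`` network that is not
    # part of this file:
    #
    #   class MyAnomalyModule(AnomalyModule):
    #       def __init__(self, hparams):
    #           super().__init__(hparams)
    #           self.model = MyBackbone()
    #
    #       def validation_step(self, batch, batch_idx):
    #           batch["anomaly_maps"] = self.model(batch["image"])
    #           return batch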

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

        Calls :meth:`validation_step`, post-processes the outputs and applies the
        image and pixel thresholds to obtain predicted labels and masks. Override
        to add any additional processing logic.

        Args:
            batch (Tensor): Current batch
            batch_idx (int): Index of current batch
            _dataloader_idx (int): Index of the current dataloader

        Returns:
            Predicted output
        """
        outputs = self.validation_step(batch, batch_idx)
        self._post_process(outputs)
        outputs["pred_labels"] = outputs["pred_scores"] >= self.image_threshold.value
        if "anomaly_maps" in outputs.keys():
            outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
        return outputs
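
    # Hedged usage sketch (``model`` is a concrete subclass instance and
    # ``predict_dataloader`` is an assumed PyTorch dataloader; neither is defined
    # in this file):
    #
    #   trainer = pl.Trainer()
    #   predictions = trainer.predict(model=model, dataloaders=predict_dataloader)
    #   # Each returned batch dict contains "pred_scores" and "pred_labels", plus
    #   # "anomaly_maps"/"pred_masks" for models that predict anomaly maps.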

    def test_step(self, batch, _):  # pylint: disable=arguments-differ
        """Calls validation_step for anomaly map/score calculation.

        Args:
            batch (Tensor): Input batch
            _: Index of the batch.

        Returns:
            Dictionary containing images, features, true labels and masks.
            These are required in `test_epoch_end` for feature concatenation.
        """
        return self.validation_step(batch, _)

    def validation_step_end(self, val_step_outputs):  # pylint: disable=arguments-differ
        """Called at the end of each validation step."""
        self._outputs_to_cpu(val_step_outputs)
        self._post_process(val_step_outputs)
        return val_step_outputs

    def test_step_end(self, test_step_outputs):  # pylint: disable=arguments-differ
        """Called at the end of each test step."""
        self._outputs_to_cpu(test_step_outputs)
        self._post_process(test_step_outputs)
        return test_step_outputs

    def validation_epoch_end(self, outputs):
        """Compute threshold and performance metrics.

        Args:
            outputs: Batch of outputs from the validation step
        """
        if self.hparams.model.threshold.adaptive:
            self._compute_adaptive_threshold(outputs)
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def test_epoch_end(self, outputs):
        """Compute and log performance metrics on the test set.

        Args:
            outputs: Batch of outputs from the test step
        """
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def _compute_adaptive_threshold(self, outputs):
        """Compute adaptive image and pixel thresholds from the validation outputs."""
        self._collect_outputs(self.image_threshold, self.pixel_threshold, outputs)
        self.image_threshold.compute()
        if "mask" in outputs[0].keys() and "anomaly_maps" in outputs[0].keys():
            self.pixel_threshold.compute()
        else:
            # Without ground-truth masks there is nothing to fit a pixel threshold on,
            # so fall back to the image-level threshold.
            self.pixel_threshold.value = self.image_threshold.value

        self.image_metrics.set_threshold(self.image_threshold.value.item())
        self.pixel_metrics.set_threshold(self.pixel_threshold.value.item())

    def _collect_outputs(self, image_metric, pixel_metric, outputs):
        """Update image- and pixel-level metrics from a list of step outputs."""
        for output in outputs:
            image_metric.cpu()
            image_metric.update(output["pred_scores"], output["label"].int())
            if "mask" in output.keys() and "anomaly_maps" in output.keys():
                pixel_metric.cpu()
                pixel_metric.update(output["anomaly_maps"].flatten(), output["mask"].flatten().int())

    def _post_process(self, outputs):
        """Compute labels based on model predictions."""
        if "pred_scores" not in outputs and "anomaly_maps" in outputs:
            # The image-level score is the maximum response over each anomaly map.
            outputs["pred_scores"] = (
                outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
            )

    def _outputs_to_cpu(self, output):
        """Move the tensors in a step output dictionary to the CPU."""
        for key, value in output.items():
            if isinstance(value, Tensor):
                output[key] = value.cpu()

    def _log_metrics(self):
        """Log computed performance metrics."""
        self.log_dict(self.image_metrics)
        if self.pixel_metrics.update_called:
            self.log_dict(self.pixel_metrics)
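
# A minimal end-to-end sketch, assuming a concrete subclass instance ``model`` and
# a Lightning ``datamodule`` that are not defined in this file. Validation computes
# the (adaptive) thresholds, which the test loop then reuses:
#
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(model, datamodule=datamodule)
#   trainer.test(model, datamodule=datamodule)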