julien.blanchon
add app
c8c12e9
"""Base Anomaly Module for Training Task."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
from abc import ABC
from typing import Any, List, Optional, Union
import pytorch_lightning as pl
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks.base import Callback
from torch import Tensor, nn
from anomalib.utils.metrics import (
AdaptiveThreshold,
AnomalyScoreDistribution,
MinMax,
get_metrics,
)
class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library. Subclasses must implement
    :meth:`validation_step`; the test and predict loops reuse it.

    Args:
        params (Union[DictConfig, ListConfig]): Configuration
    """

    def __init__(self, params: Union[DictConfig, ListConfig]):
        super().__init__()
        # Force the type for hparams so that it works with OmegaConfig style of accessing
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(params)
        self.loss: Tensor
        self.callbacks: List[Callback]

        # Thresholds start at the configured defaults; they are re-estimated at the end of every
        # validation epoch when ``model.threshold.adaptive`` is enabled.
        self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default).cpu()
        self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default).cpu()

        # Not used directly in this class; presumably consumed by subclasses or callbacks
        # (e.g. for score normalization) -- TODO confirm.
        self.training_distribution = AnomalyScoreDistribution().cpu()
        self.min_max = MinMax().cpu()

        self.model: nn.Module

        # metrics
        self.image_metrics, self.pixel_metrics = get_metrics(self.hparams)
        self.image_metrics.set_threshold(self.hparams.model.threshold.image_default)
        self.pixel_metrics.set_threshold(self.hparams.model.threshold.pixel_default)

    def forward(self, batch):  # pylint: disable=arguments-differ
        """Forward-pass input tensor to the module.

        Args:
            batch (Tensor): Input Tensor

        Returns:
            Tensor: Output tensor from the model.
        """
        return self.model(batch)

    def validation_step(self, batch, batch_idx) -> dict:  # type: ignore  # pylint: disable=arguments-differ
        """To be implemented in the subclasses.

        Must return a dict containing at least ``label`` and either ``pred_scores`` or
        ``anomaly_maps`` (plus ``mask`` when pixel-level ground truth is available).
        """
        raise NotImplementedError

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

        Runs :meth:`validation_step` on the batch, fills in ``pred_scores`` from the anomaly maps
        when the subclass did not provide them, and binarizes the results with the current
        image/pixel thresholds. Override to add any processing logic.

        Args:
            batch (Tensor): Current batch
            batch_idx (int): Index of current batch
            _dataloader_idx (int): Index of the current dataloader (unused)

        Return:
            The validation-step output dict, extended with boolean ``pred_labels`` and, when
            ``anomaly_maps`` is present, boolean ``pred_masks``.
        """
        outputs = self.validation_step(batch, batch_idx)
        self._post_process(outputs)
        outputs["pred_labels"] = outputs["pred_scores"] >= self.image_threshold.value
        if "anomaly_maps" in outputs:
            outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
        return outputs

    def test_step(self, batch, _):  # pylint: disable=arguments-differ
        """Calls validation_step for anomaly map/score calculation.

        Args:
            batch (Tensor): Input batch
            _: Index of the batch.

        Returns:
            Dictionary containing images, features, true labels and masks.
            These are required in `validation_epoch_end` for feature concatenation.
        """
        return self.validation_step(batch, _)

    def validation_step_end(self, val_step_outputs):  # pylint: disable=arguments-differ
        """Move the step outputs to CPU and derive missing ``pred_scores``."""
        self._outputs_to_cpu(val_step_outputs)
        self._post_process(val_step_outputs)
        return val_step_outputs

    def test_step_end(self, test_step_outputs):  # pylint: disable=arguments-differ
        """Move the step outputs to CPU and derive missing ``pred_scores``."""
        self._outputs_to_cpu(test_step_outputs)
        self._post_process(test_step_outputs)
        return test_step_outputs

    def validation_epoch_end(self, outputs):
        """Compute threshold and performance metrics.

        Args:
            outputs: List of per-step output dicts from the validation step.
        """
        if self.hparams.model.threshold.adaptive:
            self._compute_adaptive_threshold(outputs)
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def test_epoch_end(self, outputs):
        """Compute and log performance metrics on the test set.

        Args:
            outputs: List of per-step output dicts from the test step.
        """
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def _compute_adaptive_threshold(self, outputs):
        """Re-estimate the image/pixel thresholds from this epoch's outputs and apply them.

        Args:
            outputs: List of per-step output dicts from the validation step.
        """
        if not outputs:
            # Nothing to fit a threshold on (and ``outputs[0]`` below would raise);
            # keep the thresholds currently in use.
            return

        # The threshold metrics are only updated/computed here and are never logged via
        # ``self.log`` in this class, so Lightning does not reset them. Clear the state
        # accumulated in earlier epochs so the threshold reflects only the current model.
        self.image_threshold.reset()
        self.pixel_threshold.reset()

        self._collect_outputs(self.image_threshold, self.pixel_threshold, outputs)
        self.image_threshold.compute()
        if "mask" in outputs[0] and "anomaly_maps" in outputs[0]:
            self.pixel_threshold.compute()
        else:
            # Without pixel-level ground truth there is nothing to fit the pixel threshold on,
            # so fall back to the image-level threshold.
            self.pixel_threshold.value = self.image_threshold.value

        self.image_metrics.set_threshold(self.image_threshold.value.item())
        self.pixel_metrics.set_threshold(self.pixel_threshold.value.item())

    def _collect_outputs(self, image_metric, pixel_metric, outputs):
        """Feed every step output into the given image- and pixel-level metric objects.

        Image metrics are updated with ``pred_scores`` vs. integer ``label``. Pixel metrics are
        updated with the flattened ``anomaly_maps`` vs. flattened integer ``mask``, only for
        outputs that contain both keys.
        """
        # The ``*_step_end`` hooks move step outputs to CPU, so keep the metric states there too.
        # Done once up front instead of on every iteration.
        image_metric.cpu()
        pixel_metric.cpu()
        for output in outputs:
            image_metric.update(output["pred_scores"], output["label"].int())
            if "mask" in output and "anomaly_maps" in output:
                pixel_metric.update(output["anomaly_maps"].flatten(), output["mask"].flatten().int())

    def _post_process(self, outputs):
        """Derive image-level ``pred_scores`` when the model only produced ``anomaly_maps``.

        The score of each sample is the maximum value of its anomaly map (all non-batch
        dimensions are flattened). ``outputs`` is modified in place.
        """
        if "pred_scores" not in outputs and "anomaly_maps" in outputs:
            outputs["pred_scores"] = (
                outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
            )

    def _outputs_to_cpu(self, output):
        """Move every tensor value of a single step-output dict to CPU, in place."""
        for key, value in output.items():
            if isinstance(value, Tensor):
                output[key] = value.cpu()

    def _log_metrics(self):
        """Log computed performance metrics.

        Pixel metrics are only logged when they received at least one update, i.e. when
        pixel-level ground truth was available.
        """
        self.log_dict(self.image_metrics)
        if self.pixel_metrics.update_called:
            self.log_dict(self.pixel_metrics)