Spaces:
Build error
Build error
"""Visualizer Callback.""" | |
# Copyright (C) 2020 Intel Corporation | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions | |
# and limitations under the License. | |
from pathlib import Path | |
from typing import Any, Optional, cast | |
from warnings import warn | |
import pytorch_lightning as pl | |
from pytorch_lightning import Callback | |
from pytorch_lightning.utilities.types import STEP_OUTPUT | |
from skimage.segmentation import mark_boundaries | |
from anomalib.models.components import AnomalyModule | |
from anomalib.post_processing import Visualizer, compute_mask, superimpose_anomaly_map | |
from anomalib.pre_processing.transforms import Denormalize | |
from anomalib.utils import loggers | |
from anomalib.utils.loggers import AnomalibWandbLogger | |
from anomalib.utils.loggers.base import ImageLoggerBase | |
class VisualizerCallback(Callback):
    """Callback that visualizes the inference results of a model.

    The callback generates a figure showing the original image, the ground truth segmentation mask,
    the predicted error heat map, and the predicted segmentation mask.

    To save the images to the filesystem, add the 'local' keyword to the `project.log_images_to` parameter in the
    config.yaml file.
    """

    def __init__(self, task: str, inputs_are_normalized: bool = True):
        """Initialize the visualizer callback.

        Args:
            task (str): Task type. ``"segmentation"`` reserves one extra column in the generated figure.
            inputs_are_normalized (bool): Whether the anomaly maps found in the step outputs are already
                normalized. When ``False`` the heat map is normalized while it is superimposed on the image.
        """
        super().__init__()
        self.task = task
        self.inputs_are_normalized = inputs_are_normalized

    @staticmethod
    def _logger_key(logger: Any) -> str:
        """Return the short lowercase key used to match a logger against `log_images_to` entries.

        The key is the lowercased class name with one trailing ``"logger"`` and one leading ``"anomalib"``
        removed, e.g. ``AnomalibWandbLogger`` -> ``"wandb"``.

        Args:
            logger (Any): Logger instance attached to the trainer.

        Returns:
            str: Short key derived from the class name of ``logger``.
        """
        suffix, prefix = "logger", "anomalib"
        key = type(logger).__name__.lower()
        # Remove the exact affixes. ``str.rstrip``/``str.lstrip`` strip *character sets*, which would also eat
        # legitimate letters of the name (e.g. "mlflowlogger" -> "flow").
        if key.endswith(suffix):
            key = key[: -len(suffix)]
        if key.startswith(prefix):
            key = key[len(prefix) :]
        return key

    def _add_images(
        self,
        visualizer: Visualizer,
        module: AnomalyModule,
        trainer: pl.Trainer,
        filename: Path,
    ):
        """Save image to logger/local storage.

        Saves the image in `visualizer.figure` to the respective loggers and local storage if specified in
        `log_images_to` in `config.yaml` of the models.

        Args:
            visualizer (Visualizer): Visualizer object from which the `figure` is saved/logged.
            module (AnomalyModule): Anomaly module which holds reference to `hparams`.
            trainer (Trainer): Pytorch Lightning trainer which holds reference to `logger`
            filename (Path): Path of the input image. This name is used as name for the generated image.
        """
        # Map the short name of every attached logger (e.g. "wandb") to the logger object.
        available_loggers = {self._logger_key(logger): logger for logger in trainer.loggers}
        log_images_to = module.hparams.project.log_images_to

        # Send the figure to every requested logger.
        for log_to in log_images_to:
            if log_to == "local":
                # "local" is not a logger; it is handled by the filesystem save below and must not warn.
                continue
            if log_to not in loggers.AVAILABLE_LOGGERS:
                warn(f"{log_to} not in the list of supported image loggers.")
                continue

            candidate = available_loggers.get(log_to)
            # Only loggers implementing the image interface can receive the figure.
            if isinstance(candidate, ImageLoggerBase):
                logger: ImageLoggerBase = cast(ImageLoggerBase, candidate)  # placate mypy
                logger.add_image(
                    image=visualizer.figure,
                    name=f"{filename.parent.name}_{filename.name}",
                    global_step=module.global_step,
                )
            else:
                warn(
                    f"Requested {log_to} logging but logger object is of type: {type(module.logger)}."
                    f" Skipping logging to {log_to}"
                )

        if "local" in log_images_to:
            visualizer.save(Path(module.hparams.project.path) / "images" / filename.parent.name / filename.name)

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: AnomalyModule,
        outputs: Optional[STEP_OUTPUT],
        _batch: Any,
        _batch_idx: int,
        _dataloader_idx: int,
    ) -> None:
        """Log images at the end of every batch.

        Args:
            trainer (Trainer): Pytorch lightning trainer object; its loggers receive the generated figures.
            pl_module (LightningModule): Lightning modules derived from BaseAnomalyLightning object as
                currently only they support logging images.
            outputs (Dict[str, Any]): Outputs of the current test step.
            _batch (Any): Input batch of the current test step (unused).
            _batch_idx (int): Index of the current test batch (unused).
            _dataloader_idx (int): Index of the dataloader that yielded the current batch (unused).
        """
        assert outputs is not None  # narrows the Optional; everything below indexes into the outputs

        # Already-normalized anomaly maps must not be normalized again; raw maps still need it.
        normalize = not self.inputs_are_normalized
        # NOTE(review): the pixel threshold is also compared against the image-level `pred_score` below --
        # confirm that this is the intended threshold for the classification label.
        threshold = pl_module.pixel_metrics.threshold

        # Loop-invariant figure layout.
        # NOTE(review): the column count follows `self.task` while the ground-truth panel follows the presence
        # of "mask" in the outputs -- verify that the two always agree.
        num_cols = 6 if self.task == "segmentation" else 5
        has_mask = "mask" in outputs

        for i, (filename, image, anomaly_map, pred_score, gt_label) in enumerate(
            zip(
                outputs["image_path"],
                outputs["image"],
                outputs["anomaly_maps"],
                outputs["pred_scores"],
                outputs["label"],
            )
        ):
            image = Denormalize()(image.cpu())
            anomaly_map = anomaly_map.cpu().numpy()
            heat_map = superimpose_anomaly_map(anomaly_map, image, normalize=normalize)
            pred_mask = compute_mask(anomaly_map, threshold)
            vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")

            pred_text = "anomalous" if pred_score > threshold else "normal"
            gt_text = "anomalous" if bool(gt_label) else "normal"

            visualizer = Visualizer(num_rows=1, num_cols=num_cols, figure_size=(12, 3))
            try:
                visualizer.add_image(image=image, title="Image")
                if has_mask:
                    true_mask = outputs["mask"][i].cpu().numpy() * 255
                    visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")
                visualizer.add_image(image=heat_map, title="Predicted Heat Map")
                visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
                visualizer.add_image(image=vis_img, title="Segmentation Result")
                image_classified = visualizer.add_text(
                    image=image,
                    text=f"Pred: {pred_text}({pred_score:.3f}) \n\nGT: {gt_text}",
                )
                visualizer.add_image(image=image_classified, title="Classified Image")
                self._add_images(visualizer, pl_module, trainer, Path(filename))
            finally:
                # Always release the figure, even when adding or logging an image fails.
                visualizer.close()

    def on_test_end(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
        """Sync logs.

        Currently only ``AnomalibWandbLogger`` is called from this method. This is because logging as a single batch
        ensures that all images appear as part of the same step.

        Args:
            _trainer (pl.Trainer): Pytorch Lightning trainer whose loggers are searched for W&B loggers.
            pl_module (AnomalyModule): Anomaly module (unused).
        """
        # Look at every attached logger (as `_add_images` does) so a W&B logger is saved even when it is not
        # the module's primary logger.
        for logger in _trainer.loggers:
            if isinstance(logger, AnomalibWandbLogger):
                logger.save()