# Spaces:
# Build error
# Build error
"""GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training.

https://arxiv.org/abs/1805.06725
"""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import logging
from typing import Dict, List, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from torch import Tensor, optim

from anomalib.data.utils.image import pad_nextpow2
from anomalib.models.components import AnomalyModule

from .torch_model import GanomalyModel

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class GanomalyLightning(AnomalyModule):
    """PL Lightning Module for the GANomaly Algorithm.

    Trains the generator and the discriminator of a ``GanomalyModel`` with two
    separate Adam optimizers, and min-max normalizes the predicted scores at
    the end of every validation and test epoch using the extreme values
    observed during that epoch.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model parameters
    """

    def __init__(self, hparams: Union[DictConfig, ListConfig]):
        super().__init__(hparams)
        logger.info("Initializing Ganomaly Lightning model.")

        self.model: GanomalyModel = GanomalyModel(
            input_size=hparams.model.input_size,
            num_input_channels=3,
            n_features=hparams.model.n_features,
            latent_vec_size=hparams.model.latent_vec_size,
            extra_layers=hparams.model.extra_layers,
            add_final_conv_layer=hparams.model.add_final_conv,
            wadv=self.hparams.model.wadv,
            wcon=self.hparams.model.wcon,
            wenc=self.hparams.model.wenc,
        )

        # One label per sample of a training batch; not read anywhere else in
        # this class, kept so that external code relying on them keeps working.
        self.real_label = torch.ones(size=(self.hparams.dataset.train_batch_size,), dtype=torch.float32)
        self.fake_label = torch.zeros(size=(self.hparams.dataset.train_batch_size,), dtype=torch.float32)

        # Running extremes of the predicted scores. They start at +inf / -inf so
        # that the first observed score always replaces them.
        self.min_scores: Tensor = torch.tensor(float("inf"), dtype=torch.float32)  # pylint: disable=not-callable
        self.max_scores: Tensor = torch.tensor(float("-inf"), dtype=torch.float32)  # pylint: disable=not-callable

    def _reset_min_max(self):
        """Resets min_max scores."""
        self.min_scores = torch.tensor(float("inf"), dtype=torch.float32)  # pylint: disable=not-callable
        self.max_scores = torch.tensor(float("-inf"), dtype=torch.float32)  # pylint: disable=not-callable

    def configure_callbacks(self):
        """Configure model-specific callbacks.

        Returns:
            A single-element list holding an ``EarlyStopping`` callback whose
            metric, patience and mode are read from ``hparams.model.early_stopping``.
        """
        early_stopping = EarlyStopping(
            monitor=self.hparams.model.early_stopping.metric,
            patience=self.hparams.model.early_stopping.patience,
            mode=self.hparams.model.early_stopping.mode,
        )
        return [early_stopping]

    def configure_optimizers(self) -> List[optim.Optimizer]:
        """Configure optimizers for generator and discriminator.

        Both optimizers share the learning rate and betas from ``hparams.model``.

        Returns:
            List[optim.Optimizer]: Adam optimizers for discriminator and generator,
                in that order. ``training_step`` relies on this order.
        """
        optimizer_d = optim.Adam(
            self.model.discriminator.parameters(),
            lr=self.hparams.model.lr,
            betas=(self.hparams.model.beta1, self.hparams.model.beta2),
        )
        optimizer_g = optim.Adam(
            self.model.generator.parameters(),
            lr=self.hparams.model.lr,
            betas=(self.hparams.model.beta1, self.hparams.model.beta2),
        )
        return [optimizer_d, optimizer_g]

    def training_step(self, batch, _, optimizer_idx):  # pylint: disable=arguments-differ
        """Training step.

        Args:
            batch (Dict): Input batch containing images.
            _: Batch index (unused).
            optimizer_idx (int): Optimizer which is being called for current training step.
                Index 0 is the discriminator, anything else the generator
                (see ``configure_optimizers``).

        Returns:
            Dict[str, Tensor]: Loss
        """
        images = batch["image"]
        padded_images = pad_nextpow2(images)
        loss: Dict[str, Tensor]

        # Discriminator
        if optimizer_idx == 0:
            # forward pass
            loss_discriminator = self.model.get_discriminator_loss(padded_images)
            loss = {"loss": loss_discriminator}

        # Generator
        else:
            # forward pass
            loss_generator = self.model.get_generator_loss(padded_images)
            loss = {"loss": loss_generator}

        return loss

    def on_validation_start(self) -> None:
        """Reset min and max values for current validation epoch."""
        self._reset_min_max()
        return super().on_validation_start()

    def validation_step(self, batch, _) -> Dict[str, Tensor]:  # type: ignore # pylint: disable=arguments-differ
        """Update min and max scores from the current step.

        Args:
            batch (Dict[str, Tensor]): Input batch; the model output for
                ``batch["image"]`` is stored under ``batch["pred_scores"]``.
            _: Batch index (unused).

        Returns:
            Dict[str, Tensor]: batch
        """
        batch["pred_scores"] = self.model(batch["image"])
        self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"]))
        self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"]))
        return batch

    def validation_epoch_end(self, outputs):
        """Normalize outputs based on min/max values.

        Args:
            outputs: Per-step outputs of ``validation_step``; their
                ``"pred_scores"`` entries are replaced in place.

        Returns:
            The same ``outputs`` object with normalized scores.
        """
        logger.info("Normalizing validation outputs based on min/max values.")
        for prediction in outputs:
            prediction["pred_scores"] = self._normalize(prediction["pred_scores"])
        super().validation_epoch_end(outputs)
        return outputs

    def on_test_start(self) -> None:
        """Reset min max values before test batch starts."""
        self._reset_min_max()
        return super().on_test_start()

    def test_step(self, batch, _):
        """Update min and max scores from the current step.

        Args:
            batch (Dict[str, Tensor]): Input batch.
            _: Batch index (unused).

        Returns:
            Dict[str, Tensor]: batch
        """
        # NOTE(review): the parent's ``test_step`` is expected to populate
        # ``batch["pred_scores"]`` in place -- confirm against ``AnomalyModule``.
        super().test_step(batch, _)
        self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"]))
        self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"]))
        return batch

    def test_epoch_end(self, outputs):
        """Normalize outputs based on min/max values.

        Args:
            outputs: Per-step outputs of ``test_step``; their ``"pred_scores"``
                entries are replaced in place.

        Returns:
            The same ``outputs`` object with normalized scores.
        """
        logger.info("Normalizing test outputs based on min/max values.")
        for prediction in outputs:
            prediction["pred_scores"] = self._normalize(prediction["pred_scores"])
        super().test_epoch_end(outputs)
        return outputs

    def _normalize(self, scores: Tensor) -> Tensor:
        """Normalize the scores based on min/max of entire dataset.

        Args:
            scores (Tensor): Un-normalized scores.

        Returns:
            Tensor: Normalized scores. When the tracked minimum and maximum are
                equal, a tensor of zeros is returned instead of NaN.
        """
        min_scores = self.min_scores.to(scores.device)
        max_scores = self.max_scores.to(scores.device)
        score_range = max_scores - min_scores
        shifted = scores - min_scores

        # Every observed score was identical: 0 / 0 would turn all predictions
        # into NaN, so map them to the lower bound of the normalized range.
        if score_range == 0:
            return torch.zeros_like(shifted)

        return shifted / score_range