"""GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training.
https://arxiv.org/abs/1805.06725
"""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
from typing import Dict, List, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from torch import Tensor, optim

from anomalib.data.utils.image import pad_nextpow2
from anomalib.models.components import AnomalyModule

from .torch_model import GanomalyModel

logger = logging.getLogger(__name__)


class GanomalyLightning(AnomalyModule):
"""PL Lightning Module for the GANomaly Algorithm.
Args:
hparams (Union[DictConfig, ListConfig]): Model parameters
"""
def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(hparams)
logger.info("Initializing Ganomaly Lightning model.")
self.model: GanomalyModel = GanomalyModel(
input_size=hparams.model.input_size,
            num_input_channels=3,  # RGB input images
n_features=hparams.model.n_features,
latent_vec_size=hparams.model.latent_vec_size,
extra_layers=hparams.model.extra_layers,
add_final_conv_layer=hparams.model.add_final_conv,
wadv=self.hparams.model.wadv,
wcon=self.hparams.model.wcon,
wenc=self.hparams.model.wenc,
)
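        # Constant target labels for the adversarial loss: 1 for real images,
        # 0 for generated ones.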
self.real_label = torch.ones(size=(self.hparams.dataset.train_batch_size,), dtype=torch.float32)
self.fake_label = torch.zeros(size=(self.hparams.dataset.train_batch_size,), dtype=torch.float32)
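        # Running min/max of the anomaly scores, tracked over an epoch so the
        # scores can be min-max normalized once the epoch ends.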
self.min_scores: Tensor = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
self.max_scores: Tensor = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable

    def _reset_min_max(self):
"""Resets min_max scores."""
self.min_scores = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
self.max_scores = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable

    def configure_callbacks(self):
"""Configure model-specific callbacks."""
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]

    def configure_optimizers(self) -> List[optim.Optimizer]:
        """Configure optimizers for generator and discriminator.

        Returns:
            List[optim.Optimizer]: Adam optimizers for discriminator and generator.
        """
optimizer_d = optim.Adam(
self.model.discriminator.parameters(),
lr=self.hparams.model.lr,
betas=(self.hparams.model.beta1, self.hparams.model.beta2),
)
optimizer_g = optim.Adam(
self.model.generator.parameters(),
lr=self.hparams.model.lr,
betas=(self.hparams.model.beta1, self.hparams.model.beta2),
)
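        # The order of this list matters: Lightning passes the matching
        # optimizer_idx to training_step, so index 0 trains the discriminator
        # and index 1 the generator.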
return [optimizer_d, optimizer_g]

    def training_step(self, batch, _, optimizer_idx):  # pylint: disable=arguments-differ
        """Training step.

        Args:
            batch (Dict): Input batch containing images.
            optimizer_idx (int): Optimizer which is being called for current training step.

        Returns:
            Dict[str, Tensor]: Loss
        """
images = batch["image"]
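        # Pad images up to the next power-of-two resolution expected by the
        # fully convolutional generator and discriminator.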
padded_images = pad_nextpow2(images)
loss: Dict[str, Tensor]
# Discriminator
if optimizer_idx == 0:
            # Compute the discriminator loss on real vs. generated images.
loss_discriminator = self.model.get_discriminator_loss(padded_images)
loss = {"loss": loss_discriminator}
# Generator
else:
            # Compute the weighted generator loss (adversarial + reconstruction + encoder terms).
loss_generator = self.model.get_generator_loss(padded_images)
loss = {"loss": loss_generator}
return loss

    def on_validation_start(self) -> None:
"""Reset min and max values for current validation epoch."""
self._reset_min_max()
return super().on_validation_start()

    def validation_step(self, batch, _) -> Dict[str, Tensor]:  # type: ignore  # pylint: disable=arguments-differ
        """Update min and max scores from the current step.

        Args:
            batch (Dict[str, Tensor]): Input batch containing images.

        Returns:
            Dict[str, Tensor]: batch, with the predicted scores attached as ``pred_scores``.
        """
batch["pred_scores"] = self.model(batch["image"])
self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"]))
self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"]))
return batch

    def validation_epoch_end(self, outputs):
"""Normalize outputs based on min/max values."""
logger.info("Normalizing validation outputs based on min/max values.")
for prediction in outputs:
prediction["pred_scores"] = self._normalize(prediction["pred_scores"])
super().validation_epoch_end(outputs)
return outputs

    def on_test_start(self) -> None:
        """Reset min and max values before the test loop starts."""
self._reset_min_max()
return super().on_test_start()

    def test_step(self, batch, _):
"""Update min and max scores from the current step."""
super().test_step(batch, _)
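        # super().test_step() is expected to populate batch["pred_scores"] in place.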
self.max_scores = max(self.max_scores, torch.max(batch["pred_scores"]))
self.min_scores = min(self.min_scores, torch.min(batch["pred_scores"]))
return batch

    def test_epoch_end(self, outputs):
"""Normalize outputs based on min/max values."""
logger.info("Normalizing test outputs based on min/max values.")
for prediction in outputs:
prediction["pred_scores"] = self._normalize(prediction["pred_scores"])
super().test_epoch_end(outputs)
return outputs

    def _normalize(self, scores: Tensor) -> Tensor:
        """Normalize the scores based on min/max of the entire dataset.

        Args:
            scores (Tensor): Un-normalized scores.

        Returns:
            Tensor: Normalized scores.
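
        Example:
            With hypothetical ``min_scores=2.0`` and ``max_scores=6.0``:

            >>> scores = torch.tensor([2.0, 4.0, 6.0])
            >>> (scores - 2.0) / (6.0 - 2.0)
            tensor([0.0000, 0.5000, 1.0000])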
"""
scores = (scores - self.min_scores.to(scores.device)) / (
self.max_scores.to(scores.device) - self.min_scores.to(scores.device)
)
return scores
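

# A minimal configuration sketch for this module. The field names are taken
# from the attribute accesses above; the values are placeholders, not tuned
# defaults, and the AnomalyModule base class may require further fields:
#
#     from omegaconf import OmegaConf
#
#     config = OmegaConf.create(
#         {
#             "dataset": {"train_batch_size": 32},
#             "model": {
#                 "input_size": [256, 256],
#                 "n_features": 64,
#                 "latent_vec_size": 100,
#                 "extra_layers": 0,
#                 "add_final_conv": True,
#                 "wadv": 1,
#                 "wcon": 50,
#                 "wenc": 1,
#                 "lr": 0.0002,
#                 "beta1": 0.5,
#                 "beta2": 0.999,
#                 "early_stopping": {"metric": "image_AUROC", "patience": 3, "mode": "max"},
#             },
#         }
#     )
#     model = GanomalyLightning(config)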