julien.blanchon
add app
c8c12e9
"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.
https://arxiv.org/pdf/2107.12571v1.pdf
"""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import logging
import einops
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping
from torch import optim
from anomalib.models.cflow.torch_model import CflowModel
from anomalib.models.cflow.utils import get_logp, positional_encoding_2d
from anomalib.models.components import AnomalyModule
# Public API of this module: only the Lightning wrapper is exported.
__all__ = ["CflowLightning"]
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class CflowLightning(AnomalyModule):
    """PL Lightning Module for the CFLOW algorithm.

    Wraps a ``CflowModel`` and trains its decoders with manual optimization:
    each input batch is split into several "fiber" batches and the optimizer
    steps once per fiber batch.
    """

    def __init__(self, hparams):
        """Build the underlying CFLOW model.

        Args:
            hparams: Configuration object. It is forwarded to ``AnomalyModule``
                and ``CflowModel``; the other hooks read ``hparams.model.*``.
        """
        super().__init__(hparams)
        logger.info("Initializing Cflow Lightning model.")

        self.model: CflowModel = CflowModel(hparams)
        # Not read anywhere in this class; presumably kept for external users — TODO confirm.
        self.loss_val = 0
        # Several optimizer steps are taken per input batch (one per fiber batch),
        # so Lightning's automatic optimization loop cannot be used.
        self.automatic_optimization = False

    def configure_callbacks(self):
        """Configure model-specific callbacks.

        Returns:
            list: A single ``EarlyStopping`` callback whose monitored metric,
            patience and mode are taken from ``hparams.model.early_stopping``.
        """
        early_stopping = EarlyStopping(
            monitor=self.hparams.model.early_stopping.metric,
            patience=self.hparams.model.early_stopping.patience,
            mode=self.hparams.model.early_stopping.mode,
        )
        return [early_stopping]

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure one Adam optimizer over the parameters of all decoders.

        Only the decoders (one per entry of ``self.model.pool_layers``) are
        optimized; encoder parameters are not handed to the optimizer.

        Returns:
            Optimizer: Adam optimizer over every decoder's parameters, using the
            learning rate ``hparams.model.lr``.
        """
        decoders_parameters = [
            parameter
            for decoder_idx in range(len(self.model.pool_layers))
            for parameter in self.model.decoders[decoder_idx].parameters()
        ]
        return optim.Adam(params=decoders_parameters, lr=self.hparams.model.lr)

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of CFLOW.

        For each batch, decoder layers are trained with a dynamic fiber batch size.
        Training step is performed manually as multiple training steps are involved
        per batch of input images.

        For every pooling layer, the encoder activations are flattened into
        ``B*H*W`` fibers (one per spatial location) and randomly permuted. They
        are then split into fiber batches of ``self.model.fiber_batch_size``; the
        last fiber batch also absorbs the remainder. The layer's decoder is
        updated once per fiber batch.

        Args:
            batch: Input batch; only ``batch["image"]`` is read.
            _: Index of the batch.

        Returns:
            Dictionary with key ``"loss"``. Its value is the *sum* of the per-fiber
            losses over all layers and fiber batches, as a detached 1-element
            float64 tensor.

        Raises:
            ValueError: If a layer produces fewer fibers than the fiber batch size.
        """
        opt = self.optimizers()
        # The encoder is a fixed feature extractor: keep it in eval mode and do not
        # record an autograd graph for it (its activations are detached below anyway).
        self.model.encoder.eval()

        images = batch["image"]
        with torch.no_grad():
            activation = self.model.encoder(images)
        total_loss = torch.zeros([1], dtype=torch.float64, device=images.device)
        fiber_batch_size = self.model.fiber_batch_size

        for layer_idx, layer in enumerate(self.model.pool_layers):
            encoder_activations = activation[layer].detach()  # BxCxHxW
            batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
            # number of rows in the conditional vector (one per fiber)
            embedding_length = batch_size * im_height * im_width

            # repeats positional encoding for the entire batch 1 C H W to B C H W
            pos_encoding = einops.repeat(
                positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0),
                "b c h w-> (tile b) c h w",
                tile=batch_size,
            ).to(images.device)
            c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
            e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
            perm = torch.randperm(embedding_length)  # BHW
            decoder = self.model.decoders[layer_idx].to(images.device)

            fiber_batches = embedding_length // fiber_batch_size  # number of fiber batches
            if fiber_batches <= 0:
                # Raised explicitly (not via assert) so the check survives `python -O`.
                raise ValueError(
                    f"Not enough fibers for layer {layer!r}: {embedding_length} fibers but a fiber "
                    f"batch size of {fiber_batch_size}. Decrease the fiber batch size or increase "
                    "the batch size."
                )

            for batch_num in range(fiber_batches):  # per-fiber processing
                opt.zero_grad()
                start = batch_num * fiber_batch_size
                # The last fiber batch also takes the leftover fibers that do not fill a full batch.
                end = embedding_length if batch_num == fiber_batches - 1 else start + fiber_batch_size
                fiber_idx = perm[start:end]
                # get random vectors
                c_p = c_r[fiber_idx]  # NxP
                e_p = e_r[fiber_idx]  # NxC
                # decoder returns the transformed variable z and the log Jacobian determinant
                p_u, log_jac_det = decoder(e_p, [c_p])
                # per-fiber log-likelihood as computed by get_logp from z and log|det J|
                decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                log_prob = decoder_log_prob / dim_feature_vector  # likelihood per dim
                loss = -F.logsigmoid(log_prob)
                self.manual_backward(loss.mean())
                opt.step()
                # Detach so the running total does not keep every fiber batch's autograd graph alive.
                total_loss += loss.detach().sum()

        return {"loss": total_loss}

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Validation Step of CFLOW.

        Runs the full model on ``batch["image"]``. Presumably, as in training,
        encoder features are extracted and an anomaly map is computed — TODO
        confirm in ``CflowModel.forward``.

        Args:
            batch: Input batch; it is updated in place with an ``"anomaly_maps"`` key.
            _: Index of the batch.

        Returns:
            The same batch dictionary (images, true labels, masks, ...) with
            ``"anomaly_maps"`` added. These outputs are required in
            `validation_epoch_end` for feature concatenation.
        """
        batch["anomaly_maps"] = self.model(batch["image"])
        return batch