Spaces:
Build error
Build error
"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.

https://arxiv.org/pdf/2107.12571v1.pdf
"""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import logging | |
import einops | |
import torch | |
import torch.nn.functional as F | |
from pytorch_lightning.callbacks import EarlyStopping | |
from torch import optim | |
from anomalib.models.cflow.torch_model import CflowModel | |
from anomalib.models.cflow.utils import get_logp, positional_encoding_2d | |
from anomalib.models.components import AnomalyModule | |
# Module-level logger, named after this module so log records can be filtered per file.
logger = logging.getLogger(__name__)

# Public API of this module: only the Lightning wrapper is exported.
__all__ = ["CflowLightning"]
class CflowLightning(AnomalyModule):
    """PL Lightning Module for the CFLOW algorithm.

    Wraps a ``CflowModel`` and trains one decoder per entry of
    ``model.pool_layers`` on detached encoder activations. Optimization is
    manual because a single input batch triggers many optimizer steps (one
    per fiber batch, per pooling layer).
    """

    def __init__(self, hparams):
        """Build the underlying ``CflowModel`` and switch to manual optimization.

        Args:
            hparams: Configuration object forwarded to both the parent class
                and ``CflowModel``. This class itself reads
                ``hparams.model.lr`` and ``hparams.model.early_stopping.*``.
        """
        super().__init__(hparams)
        logger.info("Initializing Cflow Lightning model.")

        self.model: CflowModel = CflowModel(hparams)
        self.loss_val = 0
        # ``training_step`` calls ``zero_grad``/``manual_backward``/``step`` itself.
        self.automatic_optimization = False

    def configure_callbacks(self):
        """Configure model-specific callbacks.

        Returns:
            A single-element list holding an ``EarlyStopping`` callback whose
            metric, patience and mode come from ``hparams.model.early_stopping``.
        """
        early_stopping_cfg = self.hparams.model.early_stopping
        early_stopping = EarlyStopping(
            monitor=early_stopping_cfg.metric,
            patience=early_stopping_cfg.patience,
            mode=early_stopping_cfg.mode,
        )
        return [early_stopping]

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure a single optimizer shared by all decoders.

        Only decoder parameters are optimized; the encoder is not included.

        Returns:
            Optimizer: Adam optimizer over the parameters of every decoder,
            using ``hparams.model.lr`` as learning rate.
        """
        # One decoder exists per pooling layer; index by position so the same
        # decoders are selected as in ``training_step``.
        decoders_parameters = [
            parameter
            for decoder_idx, _ in enumerate(self.model.pool_layers)
            for parameter in self.model.decoders[decoder_idx].parameters()
        ]

        return optim.Adam(
            params=decoders_parameters,
            lr=self.hparams.model.lr,
        )

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of CFLOW.

        For each batch, decoder layers are trained with a dynamic fiber batch
        size. The optimization is performed manually because several
        optimizer steps are taken per batch of input images.

        Args:
            batch: Input batch; must provide the images under the ``"image"`` key.
            _: Index of the batch (unused).

        Returns:
            Dictionary with a single ``"loss"`` entry: a detached float64
            tensor of shape ``[1]`` holding the loss summed over every fiber
            of every pooling layer.

        Raises:
            ValueError: If a layer yields fewer fibers than
                ``model.fiber_batch_size``, so no fiber batch can be formed.
        """
        opt = self.optimizers()
        self.model.encoder.eval()

        images = batch["image"]
        device = images.device
        fiber_batch_size = self.model.fiber_batch_size

        # The activations are detached before use, so no gradient ever flows
        # into the encoder; skip building an autograd graph for its forward pass.
        with torch.no_grad():
            activation = self.model.encoder(images)

        avg_loss = torch.zeros([1], dtype=torch.float64, device=device)

        for layer_idx, layer in enumerate(self.model.pool_layers):
            encoder_activations = activation[layer].detach()  # BxCxHxW

            batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
            image_size = im_height * im_width
            embedding_length = batch_size * image_size  # number of rows in the conditional vector

            # repeats positional encoding for the entire batch 1 C H W to B C H W
            pos_encoding = einops.repeat(
                positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0),
                "b c h w-> (tile b) c h w",
                tile=batch_size,
            ).to(device)
            c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
            e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC

            # Random visiting order of the fibers (rows of c_r / e_r).
            perm = torch.randperm(embedding_length)  # BHW
            decoder = self.model.decoders[layer_idx].to(device)

            fiber_batches = embedding_length // fiber_batch_size  # number of fiber batches
            if fiber_batches <= 0:
                # Explicit raise instead of ``assert`` so the check survives ``python -O``.
                raise ValueError("Make sure we have enough fibers, otherwise decrease N or batch-size!")

            for batch_num in range(fiber_batches):  # per-fiber processing
                opt.zero_grad()

                start = batch_num * fiber_batch_size
                if batch_num < (fiber_batches - 1):
                    stop = start + fiber_batch_size
                else:
                    # The last fiber batch also takes the leftover fibers that
                    # do not fill a complete batch on their own.
                    stop = embedding_length
                idx = perm[start:stop]

                # get random vectors
                c_p = c_r[idx]  # NxP
                e_p = e_r[idx]  # NxC

                # decoder returns the transformed variable z and the log Jacobian determinant
                p_u, log_jac_det = decoder(e_p, [c_p])
                decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                log_prob = decoder_log_prob / dim_feature_vector  # likelihood per dim
                loss = -F.logsigmoid(log_prob)

                self.manual_backward(loss.mean())
                opt.step()

                # Accumulate a detached value: the running total is only
                # reported, never back-propagated, so it must not keep the
                # autograd graph of every fiber batch alive.
                avg_loss += loss.detach().sum()

        return {"loss": avg_loss}

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Validation Step of CFLOW.

        Runs the full model on the images of the batch and stores the result
        in the batch itself.

        Args:
            batch: Input batch; must provide the images under the ``"image"`` key.
            _: Index of the batch (unused).

        Returns:
            The same ``batch`` dictionary, extended in place with an
            ``"anomaly_maps"`` entry holding the model output.
        """
        batch["anomaly_maps"] = self.model(batch["image"])
        return batch