"""PyTorch model for CFlow model implementation."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List, Union

import einops
import torch
import torchvision
from omegaconf import DictConfig, ListConfig
from torch import nn

from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
from anomalib.models.cflow.utils import cflow_head, get_logp, positional_encoding_2d
from anomalib.models.components import FeatureExtractor


class CflowModel(nn.Module):
    """CFLOW: Conditional Normalizing Flows.

    Combines a frozen, pretrained torchvision encoder with one ``cflow_head`` decoder per
    extracted encoder layer. In ``forward`` every decoder scores the features of its layer
    (conditioned on a 2D positional encoding) and the per-layer scores are turned into an
    anomaly map by ``AnomalyMapGenerator``.

    Args:
        hparams: Configuration. Reads ``model.backbone``, ``model.layers``,
            ``model.condition_vector``, ``model.decoder``, ``model.coupling_blocks``,
            ``model.clamp_alpha``, ``model.soft_permutation``, ``model.input_size`` and
            ``dataset.fiber_batch_size``.
    """

    def __init__(self, hparams: Union[DictConfig, ListConfig]):
        super().__init__()

        # Backbone constructor looked up by name on ``torchvision.models``.
        self.backbone = getattr(torchvision.models, hparams.model.backbone)
        # Number of feature rows pushed through a decoder at once.
        self.fiber_batch_size = hparams.dataset.fiber_batch_size
        # Length of the positional-encoding condition vector fed to each decoder.
        self.condition_vector: int = hparams.model.condition_vector
        self.dec_arch = hparams.model.decoder
        self.pool_layers = hparams.model.layers

        # NOTE(review): ``pretrained=`` is deprecated in newer torchvision releases in favour of
        # ``weights=``; kept as-is for compatibility with the installed version.
        self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
        self.pool_dims = self.encoder.out_dims
        # One decoder per encoder layer, sized by that layer's entry in ``encoder.out_dims``.
        self.decoders = nn.ModuleList(
            [
                cflow_head(
                    condition_vector=self.condition_vector,
                    coupling_blocks=hparams.model.coupling_blocks,
                    clamp_alpha=hparams.model.clamp_alpha,
                    n_features=pool_dim,
                    permute_soft=hparams.model.soft_permutation,
                )
                for pool_dim in self.pool_dims
            ]
        )

        # The encoder is fixed: its parameters never receive gradients.
        for parameters in self.encoder.parameters():
            parameters.requires_grad = False

        self.anomaly_map_generator = AnomalyMapGenerator(
            image_size=tuple(hparams.model.input_size), pool_layers=self.pool_layers
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Forward-pass images into the network to extract encoder features and compute probability.

        The encoder is switched to eval mode (it is fixed). The decoders are switched to eval
        mode for the duration of the call and then restored to whatever mode they were in
        before, so calling this on a model in ``eval()`` mode no longer leaves the decoders
        in training mode.

        Args:
            images: Batch of images fed to the encoder (presumably ``B x 3 x H x W`` -- confirm
                against the data pipeline).

        Returns:
            Predicted anomaly maps, moved to the device of ``images``.
        """
        self.encoder.eval()
        decoders_were_training = self.decoders.training
        self.decoders.eval()
        try:
            with torch.no_grad():
                activation = self.encoder(images)

            distribution: List[torch.Tensor] = []
            height: List[int] = []
            width: List[int] = []
            for layer_idx, layer in enumerate(self.pool_layers):
                encoder_activations = activation[layer]  # BxCxHxW
                batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
                height.append(im_height)
                width.append(im_width)

                # Repeat the positional encoding for the entire batch: 1xPxHxW -> BxPxHxW.
                pos_encoding = einops.repeat(
                    positional_encoding_2d(self.condition_vector, im_height, im_width).unsqueeze(0),
                    "b c h w -> (tile b) c h w",
                    tile=batch_size,
                ).to(images.device)
                c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
                e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
                decoder = self.decoders[layer_idx].to(images.device)

                distribution.append(self._layer_log_likelihood(decoder, e_r, c_r, dim_feature_vector))

            output = self.anomaly_map_generator(distribution=distribution, height=height, width=width)
            return output.to(images.device)
        finally:
            # Restore the caller's mode rather than forcing training mode.
            self.decoders.train(decoders_were_training)

    def _layer_log_likelihood(
        self, decoder: nn.Module, e_r: torch.Tensor, c_r: torch.Tensor, dim_feature_vector: int
    ) -> torch.Tensor:
        """Score one layer's flattened features with its decoder, ``fiber_batch_size`` rows at a time.

        Args:
            decoder: Decoder for this layer.
            e_r: Flattened encoder features, ``(B*H*W) x C``.
            c_r: Flattened positional encodings, ``(B*H*W) x P``; row ``i`` conditions row ``i`` of ``e_r``.
            dim_feature_vector: Number of feature channels ``C``; the log-likelihood is divided by it.

        Returns:
            Per-dimension log-likelihood for every row of ``e_r``, concatenated in row order.
            An empty tensor on ``e_r``'s device when ``e_r`` has no rows.
        """
        embedding_length = e_r.shape[0]  # number of rows in the conditional vector
        chunks: List[torch.Tensor] = []
        # The last slice is shorter when the row count is not a multiple of the fiber batch
        # size (e.g. a non-full final validation batch); slicing handles that without a
        # special case.
        for start in range(0, embedding_length, self.fiber_batch_size):
            end = min(start + self.fiber_batch_size, embedding_length)
            c_p = c_r[start:end]  # NxP
            e_p = e_r[start:end]  # NxC
            with torch.no_grad():
                # The decoder returns the transformed variable z and the log Jacobian determinant.
                p_u, log_jac_det = decoder(e_p, [c_p])
            decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
            chunks.append(decoder_log_prob / dim_feature_vector)  # likelihood per dim

        if not chunks:
            return torch.empty(0, device=e_r.device)
        # Concatenate once, instead of re-copying a growing tensor on every fiber batch.
        return torch.cat(chunks)