"""PyTorch model for the PaDiM model implementation.""" # Copyright (C) 2020 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. from random import sample from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F import torchvision from torch import Tensor, nn from anomalib.models.components import FeatureExtractor, MultiVariateGaussian from anomalib.models.padim.anomaly_map import AnomalyMapGenerator from anomalib.pre_processing import Tiler DIMS = { "resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4}, "wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4}, } class PadimModel(nn.Module): """Padim Module. Args: layers (List[str]): Layers used for feature extraction input_size (Tuple[int, int]): Input size for the model. tile_size (Tuple[int, int]): Tile size tile_stride (int): Stride for tiling apply_tiling (bool, optional): Apply tiling. Defaults to False. backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". """ def __init__( self, layers: List[str], input_size: Tuple[int, int], backbone: str = "resnet18", apply_tiling: bool = False, tile_size: Optional[Tuple[int, int]] = None, tile_stride: Optional[int] = None, ): super().__init__() self.backbone = getattr(torchvision.models, backbone) self.layers = layers self.apply_tiling = apply_tiling self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers) self.dims = DIMS[backbone] # pylint: disable=not-callable # Since idx is randomly selected, save it with model to get same results self.register_buffer( "idx", torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])), ) self.idx: Tensor self.loss = None self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size) n_features = DIMS[backbone]["reduced_dims"] patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"] n_patches = patches_dims.ceil().prod().int().item() self.gaussian = MultiVariateGaussian(n_features, n_patches) if apply_tiling: assert tile_size is not None assert tile_stride is not None self.tiler = Tiler(tile_size, tile_stride) def forward(self, input_tensor: Tensor) -> Tensor: """Forward-pass image-batch (N, C, H, W) into model to extract features. Args: input_tensor: Image-batch (N, C, H, W) input_tensor: Tensor: Returns: Features from single/multiple layers. 
        Example:
            >>> x = torch.randn(32, 3, 224, 224)
            >>> features = self.feature_extractor(x)
            >>> features.keys()
            dict_keys(['layer1', 'layer2', 'layer3'])
            >>> [v.shape for v in features.values()]
            [torch.Size([32, 64, 56, 56]), torch.Size([32, 128, 28, 28]), torch.Size([32, 256, 14, 14])]
        """
        if self.apply_tiling:
            input_tensor = self.tiler.tile(input_tensor)

        with torch.no_grad():
            features = self.feature_extractor(input_tensor)
            embeddings = self.generate_embedding(features)

        if self.apply_tiling:
            embeddings = self.tiler.untile(embeddings)

        if self.training:
            output = embeddings
        else:
            output = self.anomaly_map_generator(
                embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance
            )

        return output

    def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
        """Generate embedding from hierarchical feature map.

        Args:
            features (Dict[str, Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet).

        Returns:
            Embedding vector.
        """
        embeddings = features[self.layers[0]]
        for layer in self.layers[1:]:
            layer_embedding = features[layer]
            # Upsample deeper (coarser) feature maps to the spatial size of the first layer.
            layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest")
            embeddings = torch.cat((embeddings, layer_embedding), 1)

        # Subsample the embedding channels with the randomly selected (but persisted) indices.
        idx = self.idx.to(embeddings.device)
        embeddings = torch.index_select(embeddings, 1, idx)

        return embeddings
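

# A minimal usage sketch (illustrative only, not part of the library API; it
# assumes pretrained torchvision weights can be downloaded and exercises the
# training-mode path, since the Gaussian parameters are not fitted here):
if __name__ == "__main__":
    model = PadimModel(layers=["layer1", "layer2", "layer3"], input_size=(224, 224))
    model.train()

    batch = torch.randn(8, 3, 224, 224)
    embeddings = model(batch)
    # For the resnet18 backbone, the embedding has reduced_dims=100 channels on
    # a 56x56 patch grid (224 / emb_scale): torch.Size([8, 100, 56, 56]).
    print(embeddings.shape)

    # In eval mode, forward() instead returns an anomaly map computed from
    # self.gaussian.mean and self.gaussian.inv_covariance, which are fitted
    # from training embeddings elsewhere in the library.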