"""PyTorch model for the PaDiM model implementation."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from random import sample
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler
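
# Per-backbone feature dimensions: "orig_dims" is the total channel count of the
# concatenated layer features (64 + 128 + 256 = 448 for resnet18 layers 1-3),
# "reduced_dims" the number of randomly sampled channels kept, and "emb_scale"
# the input-to-feature-map downscale factor.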
DIMS = {
"resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4},
"wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4},
}


class PadimModel(nn.Module):
"""Padim Module.
Args:
layers (List[str]): Layers used for feature extraction
input_size (Tuple[int, int]): Input size for the model.
tile_size (Tuple[int, int]): Tile size
tile_stride (int): Stride for tiling
apply_tiling (bool, optional): Apply tiling. Defaults to False.
backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18".
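
    Example:
        >>> model = PadimModel(
        ...     layers=["layer1", "layer2", "layer3"],
        ...     input_size=(224, 224),
        ...     backbone="resnet18",
        ... )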
"""
def __init__(
self,
layers: List[str],
input_size: Tuple[int, int],
backbone: str = "resnet18",
apply_tiling: bool = False,
tile_size: Optional[Tuple[int, int]] = None,
tile_stride: Optional[int] = None,
):
super().__init__()
self.backbone = getattr(torchvision.models, backbone)
self.layers = layers
self.apply_tiling = apply_tiling
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
self.dims = DIMS[backbone]
# pylint: disable=not-callable
# Since idx is randomly selected, save it with model to get same results
self.register_buffer(
"idx",
torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])),
)
self.idx: Tensor
self.loss = None
self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)
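
        # For a 224x224 input and emb_scale 4, the feature map is 56x56,
        # i.e. 3136 patch locations, each modelled by its own Gaussian.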
n_features = DIMS[backbone]["reduced_dims"]
patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"]
n_patches = patches_dims.ceil().prod().int().item()
        self.gaussian = MultiVariateGaussian(n_features, n_patches)

        if apply_tiling:
assert tile_size is not None
assert tile_stride is not None
            self.tiler = Tiler(tile_size, tile_stride)

    def forward(self, input_tensor: Tensor) -> Tensor:
"""Forward-pass image-batch (N, C, H, W) into model to extract features.
Args:
input_tensor: Image-batch (N, C, H, W)
input_tensor: Tensor:
Returns:
Features from single/multiple layers.
Example:
>>> x = torch.randn(32, 3, 224, 224)
>>> features = self.extract_features(input_tensor)
>>> features.keys()
dict_keys(['layer1', 'layer2', 'layer3'])
>>> [v.shape for v in features.values()]
[torch.Size([32, 64, 56, 56]),
torch.Size([32, 128, 28, 28]),
torch.Size([32, 256, 14, 14])]
"""
        if self.apply_tiling:
            input_tensor = self.tiler.tile(input_tensor)

        with torch.no_grad():
            features = self.feature_extractor(input_tensor)
            embeddings = self.generate_embedding(features)

        if self.apply_tiling:
            embeddings = self.tiler.untile(embeddings)

        if self.training:
            output = embeddings
        else:
            output = self.anomaly_map_generator(
                embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance
            )

        return output

    def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
"""Generate embedding from hierarchical feature map.
Args:
features (Dict[str, Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet)
Returns:
Embedding vector
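
        Example:
            Shapes below assume a resnet18 backbone with a 224x224 input, where
            the 448 concatenated channels are subsampled down to 100 (see DIMS).

            >>> features = self.feature_extractor(torch.randn(32, 3, 224, 224))
            >>> embedding = self.generate_embedding(features)
            >>> embedding.shape
            torch.Size([32, 100, 56, 56])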
"""
embeddings = features[self.layers[0]]
for layer in self.layers[1:]:
layer_embedding = features[layer]
layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest")
embeddings = torch.cat((embeddings, layer_embedding), 1)
# subsample embeddings
idx = self.idx.to(embeddings.device)
embeddings = torch.index_select(embeddings, 1, idx)
return embeddings
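

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). It
    # assumes MultiVariateGaussian exposes a fit() method, as used by the PaDiM
    # lightning module in anomalib; verify against your anomalib version.
    model = PadimModel(layers=["layer1", "layer2", "layer3"], input_size=(224, 224))

    # Training mode: forward returns per-patch embeddings.
    model.train()
    batch = torch.randn(8, 3, 224, 224)
    embedding = model(batch)  # (8, 100, 56, 56) for resnet18 at 224x224

    # Fit the per-patch Gaussians on the collected embeddings, then switch to
    # eval mode so forward returns an anomaly map at input resolution.
    model.gaussian.fit(embedding)
    model.eval()
    anomaly_map = model(batch)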