"""PyTorch model for the PaDiM model implementation."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from random import sample
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler

DIMS = {
    "resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4},
    "wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4},
}
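
# "orig_dims" is the channel count of the concatenated feature maps,
# "reduced_dims" the number of randomly kept channels (PaDiM's random
# dimensionality reduction), and "emb_scale" the downscaling factor from
# input image to embedding grid. For example, with input_size=(224, 224)
# and emb_scale=4, the embedding grid is 56x56 = 3136 patches.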


class PadimModel(nn.Module):
    """Padim Module.

    Args:
        layers (List[str]): Layers used for feature extraction.
        input_size (Tuple[int, int]): Input size for the model.
        backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18".
        apply_tiling (bool, optional): Apply tiling. Defaults to False.
        tile_size (Tuple[int, int], optional): Tile size. Required when apply_tiling is True.
        tile_stride (int, optional): Stride for tiling. Required when apply_tiling is True.
    """

    def __init__(
        self,
        layers: List[str],
        input_size: Tuple[int, int],
        backbone: str = "resnet18",
        apply_tiling: bool = False,
        tile_size: Optional[Tuple[int, int]] = None,
        tile_stride: Optional[int] = None,
    ):
        super().__init__()
        self.backbone = getattr(torchvision.models, backbone)
        self.layers = layers
        self.apply_tiling = apply_tiling
        self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
        self.dims = DIMS[backbone]
        # pylint: disable=not-callable
        # Since idx is randomly selected, save it with model to get same results
        self.register_buffer(
            "idx",
            torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])),
        )
        self.idx: Tensor
        self.loss = None
        self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)
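
        # One multivariate Gaussian (mean, inverse covariance) is estimated per
        # patch position, over the randomly selected feature channels.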
        n_features = DIMS[backbone]["reduced_dims"]
        patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"]
        n_patches = patches_dims.ceil().prod().int().item()
        self.gaussian = MultiVariateGaussian(n_features, n_patches)

        if apply_tiling:
            assert tile_size is not None
            assert tile_stride is not None
            self.tiler = Tiler(tile_size, tile_stride)

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Forward-pass an image batch (N, C, H, W) through the model.

        Args:
            input_tensor (Tensor): Image batch (N, C, H, W).

        Returns:
            Tensor: Embeddings in training mode; anomaly map in evaluation mode.

        Example:
            >>> input_tensor = torch.randn(32, 3, 224, 224)
            >>> features = self.feature_extractor(input_tensor)
            >>> features.keys()
            dict_keys(['layer1', 'layer2', 'layer3'])
            >>> [v.shape for v in features.values()]
            [torch.Size([32, 64, 56, 56]),
            torch.Size([32, 128, 28, 28]),
            torch.Size([32, 256, 14, 14])]
        """
        if self.apply_tiling:
            input_tensor = self.tiler.tile(input_tensor)

        with torch.no_grad():
            features = self.feature_extractor(input_tensor)
            embeddings = self.generate_embedding(features)

        if self.apply_tiling:
            embeddings = self.tiler.untile(embeddings)

        if self.training:
            output = embeddings
        else:
            output = self.anomaly_map_generator(
                embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance
            )

        return output

    def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
        """Generate embedding from hierarchical feature map.

        Args:
            features (Dict[str, Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet).

        Returns:
            Tensor: Embedding vector.
        """
        embeddings = features[self.layers[0]]
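        # Upsample deeper feature maps to the spatial size of the first layer
        # and stack them along the channel dimension.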
        for layer in self.layers[1:]:
            layer_embedding = features[layer]
            layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest")
            embeddings = torch.cat((embeddings, layer_embedding), 1)

        # subsample embeddings
        idx = self.idx.to(embeddings.device)
        embeddings = torch.index_select(embeddings, 1, idx)

        return embeddings
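

# A minimal usage sketch (illustrative, not part of the module; it assumes the
# pretrained torchvision weights can be downloaded and that the Gaussian
# parameters are fitted elsewhere, e.g. by the surrounding training loop,
# before the model is evaluated):
#
#     model = PadimModel(layers=["layer1", "layer2", "layer3"], input_size=(224, 224))
#     model.train()
#     embeddings = model(torch.randn(8, 3, 224, 224))  # (8, 100, 56, 56) embeddings
#     model.eval()
#     anomaly_map = model(torch.randn(1, 3, 224, 224))  # per-pixel anomaly scores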