"""PyTorch model for the STFPM model implementation.""" | |
# Copyright (C) 2020 Intel Corporation | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions | |
# and limitations under the License. | |
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor
from anomalib.models.stfpm.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler


class Loss(nn.Module):
    """Feature Pyramid Loss.

    This class implements the feature pyramid loss function proposed in the STFPM paper.

    Example:
        >>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
        >>> from anomalib.models.stfpm.torch_model import Loss
        >>> from torchvision.models import resnet18
        >>> layers = ['layer1', 'layer2', 'layer3']
        >>> teacher_model = FeatureExtractor(backbone=resnet18(pretrained=True), layers=layers)
        >>> student_model = FeatureExtractor(backbone=resnet18(pretrained=False), layers=layers)
        >>> loss = Loss()
        >>> inp = torch.rand((4, 3, 256, 256))
        >>> teacher_features = teacher_model(inp)
        >>> student_features = student_model(inp)
        >>> loss(student_features, teacher_features)
        tensor(51.2015, grad_fn=<SumBackward0>)
    """

    def __init__(self):
        super().__init__()
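        # ``reduction="sum"`` accumulates squared differences over all elements; the
        # 0.5 / (height * width) factor in ``compute_layer_loss`` rescales that sum
        # into the per-layer loss of the paper.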
        self.mse_loss = nn.MSELoss(reduction="sum")

    def compute_layer_loss(self, teacher_feats: Tensor, student_feats: Tensor) -> Tensor:
        """Compute layer loss based on Equation (1) in Section 3.2 of the paper.

        Args:
            teacher_feats (Tensor): Teacher features
            student_feats (Tensor): Student features

        Returns:
            L2 distance between teacher and student features.
        """
        height, width = teacher_feats.shape[2:]
        norm_teacher_features = F.normalize(teacher_feats)
        norm_student_features = F.normalize(student_feats)
        layer_loss = (0.5 / (width * height)) * self.mse_loss(norm_teacher_features, norm_student_features)
        return layer_loss

    def forward(self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor]) -> Tensor:
        """Compute the overall loss by summing the cosine-similarity-based losses of the individual layers.

        Args:
            teacher_features (Dict[str, Tensor]): Teacher features
            student_features (Dict[str, Tensor]): Student features

        Returns:
            Total loss, which is the sum of the layer losses.
        """
        layer_losses: List[Tensor] = []
        for layer in teacher_features.keys():
            loss = self.compute_layer_loss(teacher_features[layer], student_features[layer])
            layer_losses.append(loss)

        total_loss = torch.stack(layer_losses).sum()
        return total_loss


class STFPMModel(nn.Module):
    """STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.

    Args:
        layers (List[str]): Layers used for feature extraction.
        input_size (Tuple[int, int]): Input size for the model.
        backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18".
        apply_tiling (bool, optional): Apply tiling. Defaults to False.
        tile_size (Tuple[int, int], optional): Tile size. Required when ``apply_tiling=True``. Defaults to None.
        tile_stride (int, optional): Stride for tiling. Required when ``apply_tiling=True``. Defaults to None.
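
    Example:
        A minimal usage sketch (the layer names and sizes below are illustrative,
        not values prescribed by the model):

        >>> import torch
        >>> from anomalib.models.stfpm.torch_model import STFPMModel
        >>> model = STFPMModel(layers=['layer1', 'layer2', 'layer3'], input_size=(256, 256))
        >>> images = torch.rand((4, 3, 256, 256))
        >>> teacher_features, student_features = model(images)  # training mode
        >>> model = model.eval()
        >>> anomaly_maps = model(images)  # evaluation mode returns the anomaly maps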
""" | |

    def __init__(
        self,
        layers: List[str],
        input_size: Tuple[int, int],
        backbone: str = "resnet18",
        apply_tiling: bool = False,
        tile_size: Optional[Tuple[int, int]] = None,
        tile_stride: Optional[int] = None,
    ):
        super().__init__()
        self.backbone = getattr(torchvision.models, backbone)
        self.apply_tiling = apply_tiling
        self.teacher_model = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=layers)
        self.student_model = FeatureExtractor(backbone=self.backbone(pretrained=False), layers=layers)

        # teacher model is fixed
        for parameters in self.teacher_model.parameters():
            parameters.requires_grad = False

        self.loss = Loss()
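
        # When tiling is enabled, anomaly maps are generated at tile resolution and
        # stitched back to full-image resolution in ``forward`` via ``Tiler.untile``.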
        if self.apply_tiling:
            assert tile_size is not None
            assert tile_stride is not None
            self.tiler = Tiler(tile_size, tile_stride)
            self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(tile_size))
        else:
            self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(input_size))

    def forward(self, images):
        """Forward-pass images into the network.

        During training, the model extracts features from the teacher and student networks.
        During evaluation, it returns the predicted anomaly map.

        Args:
            images (Tensor): Batch of images.

        Returns:
            Teacher and student features when in training mode, otherwise the predicted anomaly maps.
        """
        if self.apply_tiling:
            images = self.tiler.tile(images)
        teacher_features: Dict[str, Tensor] = self.teacher_model(images)
        student_features: Dict[str, Tensor] = self.student_model(images)
        if self.training:
            output = teacher_features, student_features
        else:
            output = self.anomaly_map_generator(teacher_features=teacher_features, student_features=student_features)
            if self.apply_tiling:
                output = self.tiler.untile(output)

        return output