"""PyTorch model for CFlow model implementation."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List, Union

import einops
import torch
import torchvision
from omegaconf import DictConfig, ListConfig
from torch import nn

from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
from anomalib.models.cflow.utils import cflow_head, get_logp, positional_encoding_2d
from anomalib.models.components import FeatureExtractor


class CflowModel(nn.Module):
    """CFLOW: Conditional Normalizing Flows.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model and dataset hyperparameters.
    """

    def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
super().__init__()
self.backbone = getattr(torchvision.models, hparams.model.backbone)
self.fiber_batch_size = hparams.dataset.fiber_batch_size
self.condition_vector: int = hparams.model.condition_vector
self.dec_arch = hparams.model.decoder
self.pool_layers = hparams.model.layers
self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
self.pool_dims = self.encoder.out_dims
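        # Build one conditional normalizing-flow decoder per pooled feature level;
        # `out_dims` provides the channel count (n_features) of each selected layer.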
self.decoders = nn.ModuleList(
[
cflow_head(
condition_vector=self.condition_vector,
coupling_blocks=hparams.model.coupling_blocks,
clamp_alpha=hparams.model.clamp_alpha,
n_features=pool_dim,
permute_soft=hparams.model.soft_permutation,
)
for pool_dim in self.pool_dims
]
)
        # The encoder backbone is frozen; only the flow decoders are trained.
for parameters in self.encoder.parameters():
parameters.requires_grad = False
self.anomaly_map_generator = AnomalyMapGenerator(
image_size=tuple(hparams.model.input_size), pool_layers=self.pool_layers
)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Forward-pass images into the network to extract encoder features and compute probability.

        Args:
            images: Batch of images.

        Returns:
            Predicted anomaly maps.
        """
self.encoder.eval()
self.decoders.eval()
with torch.no_grad():
activation = self.encoder(images)
distribution = [torch.Tensor(0).to(images.device) for _ in self.pool_layers]
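        # One empty tensor per pooling layer; each accumulates the per-fiber
        # log-likelihoods that are concatenated fiber-batch by fiber-batch below.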
height: List[int] = []
width: List[int] = []
for layer_idx, layer in enumerate(self.pool_layers):
encoder_activations = activation[layer] # BxCxHxW
batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
image_size = im_height * im_width
embedding_length = batch_size * image_size # number of rows in the conditional vector
height.append(im_height)
width.append(im_width)
            # Repeat the positional encoding for the whole batch: 1xCxHxW -> BxCxHxW
pos_encoding = einops.repeat(
positional_encoding_2d(self.condition_vector, im_height, im_width).unsqueeze(0),
"b c h w-> (tile b) c h w",
tile=batch_size,
).to(images.device)
c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c") # BHWxP
e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c") # BHWxC
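            # Each row is one spatial location ("fiber") of the feature map; the flow
            # decoder below conditions every fiber on its positional encoding.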
decoder = self.decoders[layer_idx].to(images.device)
            # During validation the last batch can be partial, i.e. E / N is not a whole
            # number, so one extra fiber batch is added to cover the remainder. During
            # training E / N is assumed to be a whole number, as no errors were observed
            # in testing; should that ever change, ensure the fiber batch size (FIB) is at
            # least 1 or set `drop_last` in the dataloader to drop the last non-full batch.
fiber_batches = embedding_length // self.fiber_batch_size + int(
embedding_length % self.fiber_batch_size > 0
)
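            # e.g. (illustrative numbers) embedding_length = 1000, fiber_batch_size = 64:
            # 1000 // 64 + int(1000 % 64 > 0) = 15 + 1 = 16 fiber batches, the last one
            # covering the remaining 1000 - 15 * 64 = 40 fibers.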
for batch_num in range(fiber_batches): # per-fiber processing
if batch_num < (fiber_batches - 1):
idx = torch.arange(batch_num * self.fiber_batch_size, (batch_num + 1) * self.fiber_batch_size)
                else:  # last (possibly partial) batch: (batch_num + 1) * N would exceed the embedding length
idx = torch.arange(batch_num * self.fiber_batch_size, embedding_length)
c_p = c_r[idx] # NxP
e_p = e_r[idx] # NxC
# decoder returns the transformed variable z and the log Jacobian determinant
with torch.no_grad():
p_u, log_jac_det = decoder(e_p, [c_p])
decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                log_prob = decoder_log_prob / dim_feature_vector  # log-likelihood per feature dimension
distribution[layer_idx] = torch.cat((distribution[layer_idx], log_prob))
output = self.anomaly_map_generator(distribution=distribution, height=height, width=width)
self.decoders.train()
return output.to(images.device)
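

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API). The
# config keys mirror exactly what `__init__` reads above; the concrete values
# (backbone name, layer names, block counts, sizes) are assumptions chosen for
# demonstration, not prescribed by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf

    hparams = OmegaConf.create(
        {
            "model": {
                "backbone": "resnet18",  # any torchvision classification backbone
                "condition_vector": 128,
                "decoder": "freia-cflow",
                "layers": ["layer2", "layer3", "layer4"],
                "coupling_blocks": 8,
                "clamp_alpha": 1.9,
                "soft_permutation": False,
                "input_size": [256, 256],
            },
            "dataset": {"fiber_batch_size": 64},
        }
    )
    model = CflowModel(hparams)  # downloads pretrained backbone weights on first use
    anomaly_maps = model(torch.rand(2, 3, 256, 256))  # batch of two RGB images
    print(anomaly_maps.shape)  # one anomaly map per image at the configured input size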