"""Image Tiler.""" # Copyright (C) 2020 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. from itertools import product from math import ceil from typing import Optional, Sequence, Tuple, Union import torch import torchvision.transforms as T from torch import Tensor from torch.nn import functional as F class StrideSizeError(Exception): """StrideSizeError to raise exception when stride size is greater than the tile size.""" def compute_new_image_size(image_size: Tuple, tile_size: Tuple, stride: Tuple) -> Tuple: """This function checks if image size is divisible by tile size and stride. If not divisible, it resizes the image size to make it divisible. Args: image_size (Tuple): Original image size tile_size (Tuple): Tile size stride (Tuple): Stride Examples: >>> compute_new_image_size(image_size=(512, 512), tile_size=(256, 256), stride=(128, 128)) (512, 512) >>> compute_new_image_size(image_size=(512, 512), tile_size=(222, 222), stride=(111, 111)) (555, 555) Returns: Tuple: Updated image size that is divisible by tile size and stride. """ def __compute_new_edge_size(edge_size: int, tile_size: int, stride: int) -> int: """This function makes the resizing within the edge level.""" if (edge_size - tile_size) % stride != 0: edge_size = (ceil((edge_size - tile_size) / stride) * stride) + tile_size return edge_size resized_h = __compute_new_edge_size(image_size[0], tile_size[0], stride[0]) resized_w = __compute_new_edge_size(image_size[1], tile_size[1], stride[1]) return resized_h, resized_w def upscale_image(image: Tensor, size: Tuple, mode: str = "padding") -> Tensor: """Upscale image to the desired size via either padding or interpolation. Args: image (Tensor): Image size (Tuple): Tuple to which image is upscaled. mode (str, optional): Upscaling mode. Defaults to "padding". Examples: >>> image = torch.rand(1, 3, 512, 512) >>> image = upscale_image(image, size=(555, 555), mode="padding") >>> image.shape torch.Size([1, 3, 555, 555]) >>> image = torch.rand(1, 3, 512, 512) >>> image = upscale_image(image, size=(555, 555), mode="interpolation") >>> image.shape torch.Size([1, 3, 555, 555]) Returns: Tensor: Upscaled image. """ image_h, image_w = image.shape[2:] resize_h, resize_w = size if mode == "padding": pad_h = resize_h - image_h pad_w = resize_w - image_w image = F.pad(image, [0, pad_w, 0, pad_h]) elif mode == "interpolation": image = F.interpolate(input=image, size=(resize_h, resize_w)) else: raise ValueError(f"Unknown mode {mode}. Only padding and interpolation is available.") return image def downscale_image(image: Tensor, size: Tuple, mode: str = "padding") -> Tensor: """Opposite of upscaling. This image downscales image to a desired size. Args: image (Tensor): Input image size (Tuple): Size to which image is down scaled. mode (str, optional): Downscaling mode. Defaults to "padding". 


class Tiler:
    """Tile Image into (non)overlapping Patches. Images are tiled in order to efficiently process large images.

    Args:
        tile_size: Tile dimension for each patch
        stride: Stride length between patches
        remove_border_count: Number of border pixels to be removed from tile before untiling
        mode: Upscaling mode for image resize. Supported formats: padding, interpolation

    Examples:
        >>> import torch
        >>> from torchvision import transforms
        >>> from skimage.data import camera
        >>> tiler = Tiler(tile_size=256, stride=128)
        >>> image = transforms.ToTensor()(camera())
        >>> tiles = tiler.tile(image)
        >>> image.shape, tiles.shape
        (torch.Size([3, 512, 512]), torch.Size([9, 3, 256, 256]))

        >>> # Perform your operations on the tiles.

        >>> # Untile the patches to reconstruct the image
        >>> reconstructed_image = tiler.untile(tiles)
        >>> reconstructed_image.shape
        torch.Size([1, 3, 512, 512])
    """

    def __init__(
        self,
        tile_size: Union[int, Sequence],
        stride: Union[int, Sequence],
        remove_border_count: int = 0,
        mode: str = "padding",
        tile_count: int = 4,
    ) -> None:
        self.tile_size_h, self.tile_size_w = self.__validate_size_type(tile_size)
        self.tile_count = tile_count
        self.stride_h, self.stride_w = self.__validate_size_type(stride)
        self.remove_border_count = int(remove_border_count)
        self.overlapping = not (self.stride_h == self.tile_size_h and self.stride_w == self.tile_size_w)
        self.mode = mode

        if self.stride_h > self.tile_size_h or self.stride_w > self.tile_size_w:
            raise StrideSizeError(
                "Larger stride size than kernel size produces unreliable tiling results. "
                "Please ensure stride size is less than or equal to the tile size."
            )

        if self.mode not in ["padding", "interpolation"]:
            raise ValueError(f"Unknown tiling mode {self.mode}. Available modes are padding and interpolation")

        self.batch_size: int
        self.num_channels: int

        self.input_h: int
        self.input_w: int

        self.pad_h: int
        self.pad_w: int

        self.resized_h: int
        self.resized_w: int

        self.num_patches_h: int
        self.num_patches_w: int

    @staticmethod
    def __validate_size_type(parameter: Union[int, Sequence]) -> Tuple[int, ...]:
        """Validate the tile or stride size and convert it to a (height, width) tuple."""
        if isinstance(parameter, int):
            output = (parameter, parameter)
        elif isinstance(parameter, Sequence):
            output = (parameter[0], parameter[1])
        else:
            raise ValueError(f"Unknown type {type(parameter)} for tile or stride size. Could be int or Sequence type.")

        if len(output) != 2:
            raise ValueError(f"Length of the size type must be 2 for height and width. Got {len(output)} instead.")

        return output

    def __random_tile(self, image: Tensor) -> Tensor:
        """Randomly crop tiles from the given image.

        Args:
            image: input image to be cropped

        Returns: Randomly cropped tiles from the image
        """
        return torch.vstack([T.RandomCrop(self.tile_size_h)(image) for _ in range(self.tile_count)])

    def __unfold(self, tensor: Tensor) -> Tensor:
        """Unfolds tensor into tiles.

        This is the core function to perform tiling operation.

        Args:
            tensor: Input tensor from which tiles are generated.

        Returns: Generated tiles
        """
        # identify device type based on input tensor
        device = tensor.device

        # extract and calculate parameters
        batch, channels, image_h, image_w = tensor.shape

        self.num_patches_h = int((image_h - self.tile_size_h) / self.stride_h) + 1
        self.num_patches_w = int((image_w - self.tile_size_w) / self.stride_w) + 1

        # create an empty torch tensor for output
        tiles = torch.zeros(
            (self.num_patches_h, self.num_patches_w, batch, channels, self.tile_size_h, self.tile_size_w),
            device=device,
        )

        # fill-in output tensor with spatial patches extracted from the image
        for (tile_i, tile_j), (loc_i, loc_j) in zip(
            product(range(self.num_patches_h), range(self.num_patches_w)),
            product(
                range(0, image_h - self.tile_size_h + 1, self.stride_h),
                range(0, image_w - self.tile_size_w + 1, self.stride_w),
            ),
        ):
            tiles[tile_i, tile_j, :] = tensor[
                :, :, loc_i : (loc_i + self.tile_size_h), loc_j : (loc_j + self.tile_size_w)
            ]

        # rearrange the tiles in order [tile_count * batch, channels, tile_height, tile_width]
        tiles = tiles.permute(2, 0, 1, 3, 4, 5)
        tiles = tiles.contiguous().view(-1, channels, self.tile_size_h, self.tile_size_w)

        return tiles
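
    # Shape bookkeeping for __unfold, sketched with the class docstring example
    # (512x512 image, tile_size=256, stride=128):
    #   num_patches_h = (512 - 256) // 128 + 1 = 3 and num_patches_w = 3,
    #   so a (B, C, 512, 512) input yields (3 * 3 * B, C, 256, 256) tiles,
    #   i.e. the 9 tiles per image seen in the example above.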

    def __fold(self, tiles: Tensor) -> Tensor:
        """Fold the tiles back into the original tensor.

        This is the core method to reconstruct the original image from its tiled version.

        Args:
            tiles: Tiles from the input image, generated via __unfold method.

        Returns:
            Output that is the reconstructed version of the input tensor.
        """
        # number of channels differs between image and anomaly map, so infer from input tiles.
        _, num_channels, tile_size_h, tile_size_w = tiles.shape
        scale_h, scale_w = (tile_size_h / self.tile_size_h), (tile_size_w / self.tile_size_w)
        # identify device type based on input tensor
        device = tiles.device
        # calculate tile size after borders removed
        reduced_tile_h = tile_size_h - (2 * self.remove_border_count)
        reduced_tile_w = tile_size_w - (2 * self.remove_border_count)
        # reconstructed image dimension
        image_size = (self.batch_size, num_channels, int(self.resized_h * scale_h), int(self.resized_w * scale_w))

        # rearrange input tiles in format [tile_count, batch, channel, tile_h, tile_w]
        tiles = tiles.contiguous().view(
            self.batch_size,
            self.num_patches_h,
            self.num_patches_w,
            num_channels,
            tile_size_h,
            tile_size_w,
        )
        tiles = tiles.permute(0, 3, 1, 2, 4, 5)
        tiles = tiles.contiguous().view(self.batch_size, num_channels, -1, tile_size_h, tile_size_w)
        tiles = tiles.permute(2, 0, 1, 3, 4)

        # remove tile borders by defined count
        tiles = tiles[
            :,
            :,
            :,
            self.remove_border_count : reduced_tile_h + self.remove_border_count,
            self.remove_border_count : reduced_tile_w + self.remove_border_count,
        ]

        # create tensors to store intermediate results and outputs
        img = torch.zeros(image_size, device=device)
        lookup = torch.zeros(image_size, device=device)
        ones = torch.ones(reduced_tile_h, reduced_tile_w, device=device)

        # reconstruct image by adding patches to their respective location and
        # create a lookup for patch count in every location
        for patch, (loc_i, loc_j) in zip(
            tiles,
            product(
                range(
                    self.remove_border_count,
                    int(self.resized_h * scale_h) - reduced_tile_h + 1,
                    int(self.stride_h * scale_h),
                ),
                range(
                    self.remove_border_count,
                    int(self.resized_w * scale_w) - reduced_tile_w + 1,
                    int(self.stride_w * scale_w),
                ),
            ),
        ):
            img[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += patch
            lookup[:, :, loc_i : (loc_i + reduced_tile_h), loc_j : (loc_j + reduced_tile_w)] += ones

        # divide the reconstructed image by the lookup to average out the values
        img = torch.divide(img, lookup)
        # alternative way of removing nan values (isnan not supported by openvino)
        img[img != img] = 0  # pylint: disable=comparison-with-itself

        return img
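
    # Overlap handling in __fold, again with the 256/128 example: interior pixels
    # are covered by up to four overlapping tiles, so `img` accumulates the sum of
    # the patches while `lookup` counts how many patches touched each location;
    # dividing the two averages the overlaps back to a single value per pixel.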

    def tile(self, image: Tensor, use_random_tiling: Optional[bool] = False) -> Tensor:
        """Tiles an input image to either overlapping, non-overlapping or random patches.

        Args:
            image: Input image to tile.
            use_random_tiling: If True, randomly crop tiles instead of tiling on a regular grid.

        Examples:
            >>> from anomalib.data.tiler import Tiler
            >>> tiler = Tiler(tile_size=512, stride=256)
            >>> image = torch.rand(size=(2, 3, 1024, 1024))
            >>> image.shape
            torch.Size([2, 3, 1024, 1024])
            >>> tiles = tiler.tile(image)
            >>> tiles.shape
            torch.Size([18, 3, 512, 512])

        Returns:
            Tiles generated from the image.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)

        self.batch_size, self.num_channels, self.input_h, self.input_w = image.shape

        if self.input_h < self.tile_size_h or self.input_w < self.tile_size_w:
            raise ValueError(
                f"One of the edges of the tile size {self.tile_size_h, self.tile_size_w} "
                f"is larger than that of the image {self.input_h, self.input_w}."
            )

        self.resized_h, self.resized_w = compute_new_image_size(
            image_size=(self.input_h, self.input_w),
            tile_size=(self.tile_size_h, self.tile_size_w),
            stride=(self.stride_h, self.stride_w),
        )

        image = upscale_image(image, size=(self.resized_h, self.resized_w), mode=self.mode)

        if use_random_tiling:
            image_tiles = self.__random_tile(image)
        else:
            image_tiles = self.__unfold(image)

        return image_tiles

    def untile(self, tiles: Tensor) -> Tensor:
        """Untiles patches to reconstruct the original input image.

        If the patches overlap, the function averages the overlapping pixels
        and returns the reconstructed image.

        Args:
            tiles: Tiles from the input image, generated via tile().

        Examples:
            >>> from anomalib.data.tiler import Tiler
            >>> tiler = Tiler(tile_size=512, stride=256)
            >>> image = torch.rand(size=(2, 3, 1024, 1024))
            >>> image.shape
            torch.Size([2, 3, 1024, 1024])
            >>> tiles = tiler.tile(image)
            >>> tiles.shape
            torch.Size([18, 3, 512, 512])
            >>> reconstructed_image = tiler.untile(tiles)
            >>> reconstructed_image.shape
            torch.Size([2, 3, 1024, 1024])
            >>> torch.equal(image, reconstructed_image)
            True

        Returns:
            Output that is the reconstructed version of the input tensor.
        """
        image = self.__fold(tiles)
        image = downscale_image(image=image, size=(self.input_h, self.input_w), mode=self.mode)

        return image
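

# Illustrative usage sketch (not part of the original module): tile a batch, run an
# arbitrary per-tile operation, and reconstruct the images, mirroring the docstring
# examples above. The __main__ guard keeps the demo from running on import.
if __name__ == "__main__":
    tiler = Tiler(tile_size=512, stride=256)
    images = torch.rand(size=(2, 3, 1024, 1024))
    tiles = tiler.tile(images)  # (18, 3, 512, 512): 3 x 3 overlapping tiles per image
    tiles = tiles * 1.0  # placeholder for per-tile processing
    reconstructed = tiler.untile(tiles)  # (2, 3, 1024, 1024), overlaps averaged back
    print(torch.allclose(images, reconstructed))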