Spaces:
Build error
Build error
"""Image Tiling Tests.""" | |
import pytest | |
import torch | |
from omegaconf import ListConfig | |
from anomalib.pre_processing.tiler import StrideSizeError, Tiler | |
# Cases for tiling a single image with and without a leading batch dimension.
# Row layout presumably matches the parametrization
# (image_size, tile_size, stride, expected_shape, use_random_tiling) used by
# test_tiler_handles_single_image_without_batch_dimension — decorator not
# visible here, confirm. Rows: every rank for False, then every rank for True.
tile_data = [
    (image_size, 512, 512, torch.Size([4, 3, 512, 512]), use_random_tiling)
    for use_random_tiling in (False, True)
    for image_size in ([3, 1024, 1024], [1, 3, 1024, 1024])
]
# Cases for tile -> untile round trips. Field order presumably follows the
# signature of test_untile_non_overlapping_patches:
# (tile_size, kernel_size, stride, image_size) — parametrize decorator not
# visible here, confirm.
untile_data = [
    (
        [3, 1024, 1024],
        512,
        256,
        torch.Size([4, 3, 512, 512]),
    ),
    (
        [1, 3, 1024, 1024],
        512,
        512,
        torch.Size([4, 3, 512, 512]),
    ),
]
# Overlapping-tile cases: identical geometry, one row per resize mode
# ("padding" first, then "interpolation").
overlapping_data = [
    (
        torch.Size([1, 3, 1024, 1024]),
        512,
        256,
        torch.Size([16, 3, 512, 512]),
        mode,
    )
    for mode in ("padding", "interpolation")
]
def test_size_types_should_be_int_tuple_or_list_config(tile_size, stride):
    """Tiler should accept sizes given as int, tuple or ListConfig.

    Whatever form was passed in, the resolved tile height and stride width
    must come out as plain ``int`` values.
    """
    tiler = Tiler(tile_size=tile_size, stride=stride)
    for attribute in ("tile_size_h", "stride_w"):
        assert isinstance(getattr(tiler, attribute), int)
def test_tiler_handles_single_image_without_batch_dimension(image_size, tile_size, stride, shape, use_random_tiling):
    """A 3-D (C x H x W) image should be tiled as if it had a batch dimension.

    The patch tensor produced by ``tile`` must have exactly ``shape``.
    """
    patches = Tiler(tile_size=tile_size, stride=stride).tile(
        torch.rand(image_size),
        use_random_tiling=use_random_tiling,
    )
    assert patches.shape == shape
def test_stride_size_cannot_be_larger_than_tile_size():
    """A stride larger than the tile size must be rejected with StrideSizeError.

    With stride > tile size, consecutive tiles would leave uncovered gaps, so the
    Tiler is expected to refuse the configuration when it is constructed.
    """
    kernel_size = (128, 128)
    stride = 256
    with pytest.raises(StrideSizeError):
        # The result is never used: only the constructor's exception matters.
        Tiler(tile_size=kernel_size, stride=stride)
def test_tile_size_cannot_be_larger_than_image_size():
    """Tiling an image smaller than the tile size must raise ValueError.

    Only the ``tile`` call is inside ``pytest.raises``. An unrelated ValueError
    from building the Tiler or the input image therefore cannot make this test
    pass by accident.
    """
    tiler = Tiler(tile_size=1024, stride=512)
    image = torch.rand(1, 3, 512, 512)  # 512x512 image vs. 1024x1024 tile
    with pytest.raises(ValueError):
        tiler.tile(image)
def test_untile_non_overlapping_patches(tile_size, kernel_size, stride, image_size):
    """Untiling the tiles of an image must give back a tensor of the image's shape."""
    # NOTE(review): ``tile_size`` is accepted but never read; ``kernel_size`` is
    # what gets passed to the Tiler as its tile size. Confirm this is intended.
    tiler = Tiler(tile_size=kernel_size, stride=stride)
    tiles = tiler.tile(torch.rand(image_size))
    untiled = tiler.untile(tiles)
    assert untiled.shape == torch.Size(image_size)
def test_upscale_downscale_mode(mode):
    """Constructing a Tiler with an unsupported ``mode`` must raise ValueError."""
    # NOTE(review): the ``mode`` values come from a parametrization not visible
    # here; presumably they are invalid spellings of the supported modes.
    with pytest.raises(ValueError):
        Tiler(tile_size=(512, 512), stride=(256, 256), mode=mode)
def test_untile_overlapping_patches(image_size, kernel_size, stride, remove_border_count, tile_size, mode):
    """Tiling then untiling with overlapping tiles should reproduce the image exactly.

    ``remove_border_count`` pixels are cropped from every spatial edge of both
    the input and the reconstruction before comparing. Presumably the Tiler does
    not restore those edge pixels when the count is positive; confirm.
    """
    # NOTE(review): ``tile_size`` is accepted but not used in this test.
    tiler = Tiler(
        tile_size=kernel_size,
        stride=stride,
        remove_border_count=remove_border_count,
        mode=mode,
    )

    image = torch.rand(image_size)
    tiles = tiler.tile(image)
    reconstructed_image = tiler.untile(tiles)

    def crop_border(tensor):
        # Use explicit stop indices computed from the shape. A ``b:-b`` slice
        # becomes the empty slice ``0:0`` when b == 0. Both sides would then be
        # empty and ``torch.equal`` would pass without comparing any pixels.
        height, width = tensor.shape[-2:]
        return tensor[
            ...,
            remove_border_count : height - remove_border_count,
            remove_border_count : width - remove_border_count,
        ]

    assert torch.equal(crop_border(image), crop_border(reconstructed_image))
def test_divisible_tile_size_and_stride(image_size, tile_size, stride, mode):
    """Tiling then untiling must restore the original image shape.

    When the image size is not divisible by the tile size and stride, the Tiler
    is expected to resize the image before tiling and undo that after untiling.
    Presumably it pads or interpolates according to ``mode``. In ``"padding"``
    mode the pixel values must also survive the round trip (within float
    tolerance).
    """
    tiler = Tiler(tile_size, stride, mode=mode)
    original = torch.rand(image_size)
    restored = tiler.untile(tiler.tile(original))

    assert original.shape == restored.shape
    if mode != "padding":
        return
    assert torch.allclose(original, restored)