Spaces:
				
			
			
	
			
			
		No application file
		
	
	
	
			
			
	
	
	
	
		
		
		No application file
		
	| import collections.abc as collections | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import cv2 | |
| import kornia | |
| import numpy as np | |
| import torch | |
| from omegaconf import OmegaConf | |
class ImagePreprocessor:
    """Resize and pad image tensors, recording the geometry of the change.

    The returned dict always contains the (possibly resized) image plus the
    per-axis scale factors, a 3x3 homogeneous scaling transform, and the new
    and original image sizes, so downstream code can map coordinates back to
    the original image.
    """

    default_conf = {
        "resize": None,  # target edge length, None for no resizing
        "edge_divisible_by": None,  # round resized edges down to a multiple of this
        "side": "long",  # which side `resize` refers to: short/long/vert/horz
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
        "square_pad": False,  # zero-pad to a square after resizing
        "add_padding_mask": False,  # also return a boolean mask of valid pixels
    }

    def __init__(self, conf) -> None:
        """Merge a user config into the (struct-locked) default config."""
        super().__init__()
        default_conf = OmegaConf.create(self.default_conf)
        # Locking the struct makes unknown keys in `conf` raise instead of
        # being silently accepted.
        OmegaConf.set_struct(default_conf, True)
        self.conf = OmegaConf.merge(default_conf, conf)

    def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
        """Resize and preprocess an image, return image and resize scale.

        Args:
            img: image tensor with shape (..., H, W); typically (C, H, W).
            interpolation: optional override for `conf.interpolation`.

        Returns:
            dict with keys "image", "scales" (new/old per axis, x then y),
            "image_size" (new W, H), "transform" (3x3 scaling matrix),
            "original_image_size" (old W, H), and optionally "padding_mask".
        """
        h, w = img.shape[-2:]
        size = h, w
        if self.conf.resize is not None:
            if interpolation is None:
                interpolation = self.conf.interpolation
            size = self.get_new_image_size(h, w)
            img = kornia.geometry.transform.resize(
                img,
                size,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
                interpolation=interpolation,
            )
        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        # float() extracts plain Python scalars; passing 0-d torch tensors to
        # np.diag directly fails for CUDA tensors.
        T = np.diag([float(scale[0]), float(scale[1]), 1.0])
        data = {
            "scales": scale,
            "image_size": np.array(size[::-1]),  # stored as (W, H)
            "transform": T,
            "original_image_size": np.array([w, h]),
        }
        if self.conf.square_pad:
            # Zero-pad the shorter side so the image becomes square.
            sl = max(img.shape[-2:])
            data["image"] = torch.zeros(
                *img.shape[:-2], sl, sl, device=img.device, dtype=img.dtype
            )
            # Index the last two dims with `...` so batched (B, C, H, W)
            # inputs are padded correctly as well as (C, H, W) ones.
            data["image"][..., : img.shape[-2], : img.shape[-1]] = img
            if self.conf.add_padding_mask:
                data["padding_mask"] = torch.zeros(
                    *img.shape[:-3], 1, sl, sl, device=img.device, dtype=torch.bool
                )
                data["padding_mask"][..., : img.shape[-2], : img.shape[-1]] = True
        else:
            data["image"] = img
        return data

    def load_image(self, image_path: Path) -> dict:
        """Load the image at `image_path` and preprocess it."""
        return self(load_image(image_path))

    def get_new_image_size(
        self,
        h: int,
        w: int,
    ) -> Tuple[int, int]:
        """Compute the target (height, width) for an h x w image.

        If `conf.resize` is an iterable it is taken verbatim as (H, W);
        otherwise it is the length of the side selected by `conf.side`.

        Raises:
            ValueError: if `conf.side` is not one of short/long/vert/horz.
        """
        side = self.conf.side
        if isinstance(self.conf.resize, collections.Iterable):
            assert len(self.conf.resize) == 2
            return tuple(self.conf.resize)
        side_size = self.conf.resize
        aspect_ratio = w / h
        if side not in ("short", "long", "vert", "horz"):
            raise ValueError(
                f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
            )
        if side == "vert":
            size = side_size, int(side_size * aspect_ratio)
        elif side == "horz":
            size = int(side_size / aspect_ratio), side_size
        elif (side == "short") ^ (aspect_ratio < 1.0):
            # "short" with a landscape image, or "long" with a portrait one:
            # the constrained side is the height.
            size = side_size, int(side_size * aspect_ratio)
        else:
            size = int(side_size / aspect_ratio), side_size
        if self.conf.edge_divisible_by is not None:
            # Round each edge down to the nearest multiple of the divisor.
            df = self.conf.edge_divisible_by
            size = tuple(int(x // df * df) for x in size)
        # Always return a tuple, matching the annotated return type
        # (the original returned a list when edge_divisible_by was set).
        return tuple(size)
def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Load the image at `path` as an RGB (default) or grayscale array.

    Raises:
        FileNotFoundError: if no file exists at `path`.
        IOError: if the file exists but cannot be decoded as an image.
    """
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    data = cv2.imread(str(path), flag)
    if data is None:
        raise IOError(f"Could not read image at {path}.")
    # OpenCV decodes color images as BGR; flip the channel axis to get RGB.
    return data if grayscale else data[..., ::-1]
def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Scale pixel values to [0, 1] and reorder to a CxHxW float tensor.

    Accepts HxWxC color images or HxW grayscale images; anything else raises
    ValueError.
    """
    if image.ndim == 2:
        chw = image[None]  # promote grayscale to a single-channel image
    elif image.ndim == 3:
        chw = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return torch.tensor(chw / 255.0, dtype=torch.float)
def load_image(path: Path, grayscale=False) -> torch.Tensor:
    """Read the image at `path` and return it as a normalized float tensor."""
    return numpy_image_to_torch(read_image(path, grayscale=grayscale))
