# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections
from typing import Callable, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torchvision.transforms.functional as tvf
from scipy.spatial.transform import Rotation

from ..utils.geometry import from_homogeneous, to_homogeneous
from ..utils.wrappers import Camera


def _parse_size(size: Union[int, Sequence, np.ndarray]) -> Tuple[int, int]:
    """Convert a target size into an integer ``(width, height)`` pair.

    An int yields a square size; a sequence or array is read as
    ``(width, height)`` and each entry is cast with ``int``.

    Raises:
        ValueError: if ``size`` is neither an int nor a sequence/array, or if
            a sequence/array does not have exactly two entries.
    """
    if isinstance(size, int):
        return size, size
    if isinstance(size, (collections.abc.Sequence, np.ndarray)):
        w_new, h_new = (int(x) for x in size)
        return w_new, h_new
    raise ValueError(f"Incorrect new size: {size}")


def rectify_image(
    image: torch.Tensor,
    cam: Camera,
    roll: float,
    pitch: Optional[float] = None,
    valid: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Warp an image by rotating the camera rays of its pixels.

    For every output pixel, the pixel coordinates are mapped through
    ``cam.normalize``, lifted to homogeneous rays ``r``, rotated to ``R @ r``,
    projected back and mapped through ``cam.denormalize``. The input image is
    then bilinearly sampled at the resulting locations.

    Args:
        image: image tensor. ``image[None]`` is passed to ``grid_sample`` as a
            4-D batch, so this is expected to be (C, H, W).
        cam: camera used to convert between pixel and normalized coordinates.
        roll: rotation angle in degrees about the "Z" axis of the Euler
            sequence.
        pitch: optional angle in degrees about the "X" axis. When given, the
            rotation is built from the scipy Euler sequence "ZX" (uppercase =
            intrinsic rotations) with angles ``(roll, pitch)``.
        valid: optional mask of valid input pixels, expected to be (H, W)
            since it is expanded with ``[None, None]`` for ``grid_sample``.

    Returns:
        ``(rectified, valid)``. ``rectified`` is the warped image with the same
        shape as ``image``. ``valid`` is an (H, W) boolean mask of output
        pixels:
        - without an input mask, True where the sampling location falls
          within the image bounds;
        - otherwise, the input mask sampled with nearest-neighbour
          interpolation.
    """
    *_, h, w = image.shape
    # Pixel coordinates (x, y) of every pixel, shape (H, W, 2).
    grid = torch.meshgrid(
        [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
        indexing="xy",
    )
    grid = torch.stack(grid, -1).to(image.dtype)

    if pitch is not None:
        args = ("ZX", (roll, pitch))
    else:
        args = ("Z", roll)
    R = Rotation.from_euler(*args, degrees=True).as_matrix()
    R = torch.from_numpy(R).to(image)  # match the image device and dtype

    # Row-vector form: (rays @ R.T) == (R @ ray) for each ray.
    grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
    grid_rect = cam.denormalize(from_homogeneous(grid_rect))

    # Map pixel coordinates to grid_sample's [-1, 1] range. The +0.5 offset
    # pairs with align_corners=False, where -1/1 denote the outer pixel edges.
    grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
    rectified = torch.nn.functional.grid_sample(
        image[None],
        grid_norm[None],
        align_corners=False,
        mode="bilinear",
    ).squeeze(0)

    if valid is None:
        valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
    else:
        valid = (
            torch.nn.functional.grid_sample(
                valid[None, None].float(),
                grid_norm[None],
                align_corners=False,
                mode="nearest",
            )[0, 0]
            > 0
        )
    return rectified, valid


def resize_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    fn: Optional[Callable] = None,
    camera: Optional[Camera] = None,
    valid: Optional[torch.Tensor] = None,
):
    """Resize an image to a fixed size, or according to max or min edge.

    Args:
        image: tensor whose last two dimensions are (H, W).
        size: target size.
            - If ``fn`` is None: an int (square output) or a
              ``(width, height)`` pair.
            - Otherwise: an int, and both edges are scaled by
              ``size / fn(h, w)`` (e.g. ``fn=max`` fits the longer edge to
              ``size``).
        fn: optional reduction over ``(h, w)`` selecting the reference edge.
        camera: optional camera, rescaled with ``camera.scale(scale)`` when
            the image size changes.
        valid: optional mask tensor with the same (H, W) size as the image.
            It is resized with nearest-neighbour interpolation.

    Returns:
        A list ``[image, scale]``, followed by ``camera`` if one was given and
        then ``valid`` if one was given. ``scale`` is ``(scale_x, scale_y)``.

    Raises:
        ValueError: if ``size`` has an unsupported type (``fn`` is None).
    """
    *_, h, w = image.shape
    if fn is not None:
        assert isinstance(size, int)
        scale = size / fn(h, w)
        h_new, w_new = (int(round(x * scale)) for x in (h, w))
        scale = (scale, scale)
    else:
        w_new, h_new = _parse_size(size)
        scale = (w_new / w, h_new / h)

    if (w, h) != (w_new, h_new):
        mode = tvf.InterpolationMode.BILINEAR
        image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True)
        # Clamp to [0, 1]; presumably the image is float in that range.
        image.clip_(0, 1)
        if camera is not None:
            camera = camera.scale(scale)
        if valid is not None:
            # Nearest-neighbour keeps the mask binary.
            valid = tvf.resize(
                valid.unsqueeze(0),
                (h_new, w_new),
                interpolation=tvf.InterpolationMode.NEAREST,
            ).squeeze(0)

    ret = [image, scale]
    if camera is not None:
        ret.append(camera)
    if valid is not None:
        ret.append(valid)
    return ret


def pad_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    camera: Optional[Camera] = None,
    valid: Optional[torch.Tensor] = None,
    crop_and_center: bool = False,
):
    """Pad (and optionally center-crop) an image to a target size.

    The image is copied into a zero-filled canvas of the target size,
    anchored at the top-left corner, so padding is always added at the
    bottom/right. With ``crop_and_center=True``, any dimension larger than
    the target is cropped symmetrically:
    - ``np.round`` of half the excess is removed from the left/top;
    - the remainder is removed from the right/bottom.
    Smaller dimensions are still padded at the bottom/right, not centered.

    Args:
        image: tensor whose last two dimensions are (H, W). Leading
            dimensions are preserved.
        size: an int (square target) or a ``(width, height)`` pair.
        camera: optional camera, updated with
            ``camera.crop(offset, (w_new, h_new))``. The offset is the number
            of columns/rows removed from the left/top.
        valid: optional mask of valid input pixels, cropped like the image.
            Presumably (H, W); it is written into an (h_new, w_new) mask.
        crop_and_center: allow inputs larger than the target by
            center-cropping them.

    Returns:
        ``(out, out_valid)``, or ``(out, out_valid, camera)`` if a camera was
        given.
        - ``out``: ``image`` itself (not a copy) when it already has the
          target size.
        - ``out_valid``: an (h_new, w_new) boolean mask, True where image
          content (or the given ``valid`` mask) was placed.
        Both outputs are on ``image.device``.

    Raises:
        ValueError: if ``size`` has an unsupported type.
        AssertionError: if the image is larger than the target and
            ``crop_and_center`` is False.
    """
    w_new, h_new = _parse_size(size)
    *c, h, w = image.shape

    if crop_and_center:
        # Positive entries are pixels to crop; negative entries mean padding.
        diff = np.array([w - w_new, h - h_new])
        left, top = left_top = np.round(diff / 2).astype(int)
        right, bottom = diff - left_top
    else:
        assert h <= h_new, f"Image height {h} exceeds target height {h_new}"
        assert w <= w_new, f"Image width {w} exceeds target width {w_new}"
        top = bottom = left = right = 0

    # Destination is always the top-left region of the canvas. The source
    # drops only the positive (crop) margins.
    slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
    slice_in = np.s_[
        ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
    ]

    if (w, h) == (w_new, h_new):
        out = image
    else:
        # Allocate on the input's device so padded and unpadded outputs agree.
        out = torch.zeros(
            (*c, h_new, w_new), dtype=image.dtype, device=image.device
        )
        out[slice_out] = image[slice_in]

    if camera is not None:
        camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))

    out_valid = torch.zeros((h_new, w_new), dtype=torch.bool, device=image.device)
    out_valid[slice_out] = True if valid is None else valid[slice_in]

    if camera is not None:
        return out, out_valid, camera
    return out, out_valid