import math
import random
from typing import Tuple, Optional

import kornia
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchvision.transforms as tf

from scenedino.models.backbones.dino.decoder import NoDecoder

import logging

logger = logging.getLogger("training")


class MultiScaleCropGT_kornia(nn.Module):
    """This class implements multi-scale-crop augmentation for DINO features."""

    def __init__(
        self,
        gt_encoder: nn.Module,
        num_views: int = 8,
        image_size: Tuple[int, int] = (192, 640),
        feature_stride: int = 16,
    ) -> None:
        """Constructor method.

        Args:
            gt_encoder (nn.Module): Encoder producing the ground-truth DINO features.
            num_views (int): Number of views per image. Default 8.
            image_size (Tuple[int, int]): Image size (height, width). Default (192, 640).
            feature_stride (int): Stride of the features. Default 16.
        """
        # Call super constructor
        super(MultiScaleCropGT_kornia, self).__init__()
        # GT encoder
        self.gt_encoder = gt_encoder
        # Save parameters
        self.augmentations_per_sample: int = num_views
        self.feature_stride: int = feature_stride
        # Init augmentations
        image_ratio = image_size[0] / image_size[1]
        augmentations = (
            kornia.augmentation.RandomHorizontalFlip(p=0.5),
            # kornia.augmentation.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            kornia.augmentation.RandomResizedCrop(
                scale=(0.5, 1.0),
                size=tuple(image_size),  # Here you need to set your resolution
                ratio=(image_ratio / 1.2, image_ratio * 1.2),
                p=1.0,
            ),
        )
        self.augmentations: nn.Module = kornia.augmentation.VideoSequential(*augmentations, same_on_frame=True)

    @staticmethod
    def _affine_transform_valid_pixels(transform: Tensor, mask: Tensor) -> Tensor:
        """Applies affine transform to a mask of ones to estimate valid pixels.

        Args:
            transform (Tensor): Affine transform of the shape [B, 3, 3].
            mask (Tensor): Mask of the shape [B, 1, H, W].

        Returns:
            valid_pixels (Tensor): Mask of valid pixels of the shape [B, 1, H, W].
        """
        # Get shape
        H, W = mask.shape[2:]  # type: int, int
        # Resample mask map
        valid_pixels: Tensor = kornia.geometry.warp_perspective(
            mask,
            transform,
            (H, W),
            mode="nearest",
        )
        # Threshold mask
        valid_pixels = torch.where(  # type: ignore
            valid_pixels > 0.999, torch.ones_like(valid_pixels), torch.zeros_like(valid_pixels)
        )
        return valid_pixels
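
    # Illustrative sketch (assumption, not original code): a transform that
    # translates content by two pixels makes part of the warped all-ones mask
    # sample from outside the source image; the zero padding there pulls those
    # values below the 0.999 threshold, so they are reported as invalid, e.g.
    #   t = torch.tensor([[[1., 0., 2.], [0., 1., 0.], [0., 0., 1.]]])
    #   MultiScaleCropGT_kornia._affine_transform_valid_pixels(t, torch.ones(1, 1, 4, 4))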
""" # Get shape B, N, C, H, W = features.shape # type: int, int, int, int, int # Get base and augmented views features_base = features[:, -2:] features_augmented = features[:, :-2] # Combine batch dimension and view dimension features_augmented = features_augmented.flatten(0, 1) transforms = transforms.flatten(0, 1) # Rescale transformation transforms[:, 0, -1] = transforms[:, 0, -1] #/ float(self.feature_stride) transforms[:, 1, -1] = transforms[:, 1, -1] #/ float(self.feature_stride) # Invert transformations transforms_inv: Tensor = torch.inverse(transforms) # Resample optical flow map features_resampled: Tensor = kornia.geometry.warp_perspective( features_augmented, transforms_inv, (H, W), mode="bilinear", ) # Separate batch and view dimension again features_resampled = features_resampled.reshape(B, -1, C, H, W) # Add base views features_resampled = torch.cat((features_resampled, features_base), dim=1) # Reverse flip features_resampled[:, -2] = features_resampled[:, -2].flip(dims=(-1,)) # Compute valid pixels mask: Tensor = torch.ones( B, N - 2, 1, H, W, dtype=features_resampled.dtype, device=features_resampled.device ) mask = mask.flatten(0, 1) valid_pixels: Tensor = self._affine_transform_valid_pixels(transforms_inv, mask) valid_pixels = valid_pixels.reshape(B, N - 2, 1, H, W) valid_pixels = F.pad(valid_pixels, (0, 0, 0, 0, 0, 0, 0, 2), value=1) # Set invalid flow vectors to zero features_resampled[valid_pixels.repeat(1, 1, C, 1, 1) == 0.0] = torch.nan # Average optical flow over different views given the sum valid pixels for the specific pixel # logger.info(features_resampled.shape) return features_resampled.nanmean(dim=1) def _get_augmentations(self, images: Tensor) -> Tuple[Tensor, Tensor]: """Forward pass generates different augmentations of the input images. Args: images (Tensor): Images of the shape [B, 3, H, W] Returns: images_augmented (Tensor): Augmented images of the shape [B, N, H, W]. transforms (Tensor): Transformations of the shape [B, N, 3, 3]. 
""" # Add dummy dimension shape is [B, num_views, 3, H, W] images = images[:, None] # Init tensor to store transformations transformations: Tensor = torch.empty( images.shape[0], self.augmentations_per_sample - 2, 3, 3, dtype=torch.float32, device=images.device ) # Init tensor to store augmented images images_augmented: Tensor = torch.empty_like(images) images_augmented = images_augmented[:, None].repeat_interleave(self.augmentations_per_sample, dim=1) # Save original and flipped images images_augmented[:, -1] = images.clone() images_augmented[:, -2] = images.clone().flip(dims=(-1,)) # Apply geometric augmentations for index in range(images.shape[0]): images_repeated: Tensor = images[index][None].repeat_interleave(self.augmentations_per_sample - 2, dim=0) images_augmented[index, :-2] = self.augmentations(images_repeated) transformations[index] = self.augmentations.get_transformation_matrix( images_repeated, self.augmentations._params ) return images_augmented[:, :, 0], transformations def forward_chunk(self, images): batch_size, _, h, w = images.shape # Perform augmentation images_aug, transformations = self._get_augmentations(images) # Get representations features = self.gt_encoder(images_aug.flatten(0, 1))[-1] features = F.interpolate(features, size=(h, w), mode="bilinear") # features = features.repeat_interleave(self.feature_stride, -1).repeat_interleave(self.feature_stride, -2) _, dino_dim, _, _ = features.shape features = features.view(batch_size, -1, dino_dim, h, w) chunks = torch.chunk(features, chunks=4, dim=2) # Split into 4 parts along dim=3 chunks = [self._accumulate_predictions(chunk, transformations) for chunk in chunks] features_accumulated = torch.cat(chunks, dim=1) # features_accumulated = self._accumulate_predictions(features, transformations) return features_accumulated / torch.linalg.norm(features_accumulated, dim=1, keepdim=True) def forward(self, images): max_chunk = 16 aug_no_images = images.shape[0] * self.augmentations_per_sample if aug_no_images > max_chunk: no_chunks = aug_no_images // max_chunk images = torch.chunk(images, no_chunks) features = [self.forward_chunk(image) for image in images] features = torch.cat(features, dim=0) return [features] else: return [self.forward_chunk(images)] class InterpolatedGT(nn.Module): def __init__(self, arch: str, gt_encoder: nn.Module, image_size: Tuple[int, int]): super().__init__() self.upsampler = NoDecoder(image_size, arch, normalize_features=False) self.gt_encoder = gt_encoder def forward(self, x): gt_patches = self.gt_encoder(x) return self.upsampler(gt_patches) def _get_affine(params, crop_size, batch_size): # construct affine operator affine = torch.zeros(batch_size, 2, 3) aspect_ratio = float(crop_size[0]) / float(crop_size[1]) for i, (dy, dx, alpha, scale, flip) in enumerate(params): # R inverse sin = math.sin(alpha * math.pi / 180.) cos = math.cos(alpha * math.pi / 180.) # inverse, note how flipping is incorporated affine[i, 0, 0], affine[i, 0, 1] = flip * cos, sin * aspect_ratio affine[i, 1, 0], affine[i, 1, 1] = -sin / aspect_ratio, cos # T inverse Rinv * t == R^T * t affine[i, 0, 2] = -1. * (cos * dx + sin * dy) affine[i, 1, 2] = -1. 


class MultiScaleCropGT(nn.Module):
    def __init__(self, gt_encoder: nn.Module, num_views: int, scale_from: float = 0.4,
                 grid_sample_batch: Optional[int] = 96):
        super().__init__()
        self.gt_encoder = gt_encoder
        self.num_views = num_views
        self.augmentation = MaskRandScaleCrop(scale_from)
        self.grid_sample_batch = grid_sample_batch

    def forward(self, x):
        result = None
        count = 0
        batch_size, _, h, w = x.shape
        for i in range(self.num_views):
            if i > 0:
                x, params = self.augmentation(x)
            else:
                params = [[0., 0., 0., 1., 1.] for _ in range(x.shape[0])]
            gt_patches = self.gt_encoder(x)[-1]
            affine = _get_affine(params, (h, w), batch_size).to(x.device)
            affine_grid_gt = F.affine_grid(affine, x.size(), align_corners=False)
            if self.grid_sample_batch:
                d = gt_patches.shape[1]
                assert d % self.grid_sample_batch == 0
                # Warp the features channel-block by channel-block to bound memory
                for idx in range(0, d, self.grid_sample_batch):
                    gt_aligned_batch = F.grid_sample(gt_patches[:, idx:idx + self.grid_sample_batch],
                                                     affine_grid_gt, mode="bilinear", align_corners=False)
                    if result is None:
                        result = torch.zeros(batch_size, d, h, w, device=x.device)
                    result[:, idx:idx + self.grid_sample_batch] += gt_aligned_batch
            else:
                gt_aligned = F.grid_sample(gt_patches, affine_grid_gt, mode="bilinear", align_corners=False)
                if result is None:
                    result = 0
                result += gt_aligned
            # Count, per pixel, how many views sampled from inside the image
            within_bounds_x = (affine_grid_gt[..., 0] >= -1) & (affine_grid_gt[..., 0] <= 1)
            within_bounds_y = (affine_grid_gt[..., 1] >= -1) & (affine_grid_gt[..., 1] <= 1)
            not_padded_mask = within_bounds_x & within_bounds_y
            count += not_padded_mask.unsqueeze(1)
        count[count == 0] = 1
        return [result.div_(count)]


class MaskRandScaleCrop(object):
    def __init__(self, scale_from):
        self.scale_from = scale_from

    def get_params(self, h, w):
        new_scale = random.uniform(self.scale_from, 1)
        new_h = int(new_scale * h)
        new_w = int(new_scale * w)
        i = random.randint(0, h - new_h)
        j = random.randint(0, w - new_w)
        flip = 1 if random.random() > 0.5 else -1
        return i, j, new_h, new_w, new_scale, flip

    def __call__(self, images, affine=None):
        if affine is None:
            affine = [[0., 0., 0., 1., 1.] for _ in range(len(images))]
        _, H, W = images[0].shape
        i2 = H / 2
        j2 = W / 2
        for k, image in enumerate(images):
            ii, jj, h, w, s, flip = self.get_params(H, W)
            if s == 1.:
                continue  # no change in scale
            # Displacement of the centre
            dy = ii + h / 2 - i2
            dx = jj + w / 2 - j2
            affine[k][0] = dy
            affine[k][1] = dx
            affine[k][3] = 1 / s
            # affine[k][4] = flip
            assert ii >= 0 and jj >= 0
            image_crop = tf.functional.crop(image, ii, jj, h, w)
            images[k] = tf.functional.resize(image_crop, (H, W), tf.InterpolationMode.BILINEAR)
        return images, affine
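

# --- Minimal usage sketch (assumption, not part of the original module) ---
# Shows how MultiScaleCropGT can be driven end to end. "_ToyEncoder" is a
# hypothetical stand-in; in practice gt_encoder is a frozen DINO backbone that
# returns a list of feature maps, the last of which is consumed here. Device
# placement follows the input, so this sketch runs on CPU.
if __name__ == "__main__":
    class _ToyEncoder(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Coarse 8x-downsampled "features" standing in for DINO patch tokens
            self.conv = nn.Conv2d(3, 64, kernel_size=8, stride=8)

        def forward(self, x: Tensor):
            return [self.conv(x)]

    model = MultiScaleCropGT(_ToyEncoder(), num_views=4, grid_sample_batch=None)
    images = torch.rand(2, 3, 192, 640)
    with torch.no_grad():
        (features,) = model(images)
    print(features.shape)  # expected: torch.Size([2, 64, 192, 640])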