import random

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor

from ..types import AnyExample, AnyViews

def rescale(
    image: Float[Tensor, "3 h_in w_in"],
    shape: tuple[int, int],
) -> Float[Tensor, "3 h_out w_out"]:
    """Resize an image to `shape` using PIL's Lanczos filter.

    The tensor is round-tripped through uint8, so values are quantized to
    1/255 steps in the process.
    """
    h, w = shape
    image_new = (image * 255).clip(min=0, max=255).type(torch.uint8)
    image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy()
    image_new = Image.fromarray(image_new)
    image_new = image_new.resize((w, h), Image.LANCZOS)
    image_new = np.array(image_new) / 255
    image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device)
    return rearrange(image_new, "h w c -> c h w")
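
# A minimal usage sketch (hypothetical shapes, not from the pipeline):
#
#   img = torch.rand(3, 480, 640)
#   out = rescale(img, (240, 320))
#   assert out.shape == (3, 240, 320)
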
def rescale_depth(
    depth: Float[Tensor, "*channel h_in w_in"],
    shape: tuple[int, int],
) -> Float[Tensor, "h_out w_out"]:
    """Resize a depth map with nearest-neighbor interpolation.

    Nearest-neighbor sampling avoids averaging depth values across object
    boundaries, which would produce depths that lie on no actual surface.
    """
    h, w = shape
    # cv2.resize expects an (h, w) array, so drop any leading singleton dims.
    depth_new = depth.squeeze().detach().cpu().numpy()
    depth_new = cv2.resize(depth_new, (w, h), interpolation=cv2.INTER_NEAREST)
    return torch.from_numpy(depth_new).to(depth.device)
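
# Sketch (hypothetical values) of why nearest-neighbor matters here: resizing
# only ever copies existing depth values, never blends them.
#
#   d = torch.tensor([[1.0, 1.0], [9.0, 9.0]])
#   rescale_depth(d, (4, 4))  # contains only 1.0 and 9.0, never a blended 5.0
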
def center_crop(
    images: Float[Tensor, "*#batch c h w"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
    shape: tuple[int, int],
    depths: Float[Tensor, "*#batch h w"] | None = None,
) -> tuple[
    Float[Tensor, "*#batch c h_out w_out"],  # updated images
    Float[Tensor, "*#batch 3 3"],  # updated intrinsics
    Float[Tensor, "*#batch h_out w_out"] | None,  # updated depths
]:
    """Center-crop images (and optionally depths) and update the intrinsics.

    The intrinsics are assumed to be normalized by the image size, i.e.
    fx = fx_pixels / width and fy = fy_pixels / height, so cropping rescales
    the focal lengths while a centered principal point stays at 0.5.
    """
    *_, h_in, w_in = images.shape
    h_out, w_out = shape
    # Note that odd input dimensions induce half-pixel misalignments.
    row = (h_in - h_out) // 2
    col = (w_in - w_out) // 2
    # Center-crop the images (and depths, if given).
    images = images[..., :, row : row + h_out, col : col + w_out]
    if depths is not None:
        depths = depths[..., row : row + h_out, col : col + w_out]
    # Adjust the intrinsics to account for the cropping.
    intrinsics = intrinsics.clone()
    intrinsics[..., 0, 0] *= w_in / w_out  # fx
    intrinsics[..., 1, 1] *= h_in / h_out  # fy
    return images, intrinsics, depths
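
# Sketch of the intrinsics update (hypothetical numbers): cropping a 100-px-wide
# image to 50 px doubles the normalized focal length, since fx = fx_pixels / w.
#
#   K = torch.eye(3)  # fx = fy = 1.0 in normalized units
#   imgs = torch.rand(1, 3, 100, 100)
#   _, K_new, _ = center_crop(imgs, K[None], (50, 50))
#   assert K_new[0, 0, 0] == 2.0
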
def rescale_and_crop(
    images: Float[Tensor, "*#batch c h w"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
    shape: tuple[int, int],
    intr_aug: bool = False,
    scale_range: tuple[float, float] = (0.77, 1.0),
    depths: Float[Tensor, "*#batch 1 h w"] | None = None,
) -> tuple[
    Float[Tensor, "*#batch c h_out w_out"],  # updated images
    Float[Tensor, "*#batch 3 3"],  # updated intrinsics
    Float[Tensor, "*#batch h_out w_out"] | None,  # updated depths
]:
    """Resize inputs so they fully cover `shape`, then center-crop to it.

    Accepts either a batched tensor or a list of per-view tensors with varying
    resolutions. Intrinsics are assumed to be normalized by the image size
    (see `center_crop`). With `intr_aug`, the crop is taken at a randomly
    shrunken size and resized back to `shape`, which augments the effective
    focal length.
    """
    if isinstance(images, list):
        # The views have different resolutions, so process them one by one.
        images_new = []
        intrinsics_new = []
        depths_new = [] if depths is not None else None
        for i in range(len(images)):
            image = images[i]
            intrinsic = intrinsics[i]
            *_, h_in, w_in = image.shape
            h_out, w_out = shape
            # Scale so the target shape is fully covered, then center-crop.
            scale_factor = max(h_out / h_in, w_out / w_in)
            h_scaled = round(h_in * scale_factor)
            w_scaled = round(w_in * scale_factor)
            image = F.resize(image, (h_scaled, w_scaled))
            image = F.center_crop(image, (h_out, w_out))
            images_new.append(image)
            # The resize leaves normalized intrinsics unchanged; the crop
            # rescales the focal lengths by scaled size / cropped size.
            intrinsic_new = intrinsic.clone()
            intrinsic_new[..., 0, 0] *= w_scaled / w_out  # fx
            intrinsic_new[..., 1, 1] *= h_scaled / h_out  # fy
            intrinsics_new.append(intrinsic_new)
            if depths is not None:
                # Put the depth through the same geometry as its image.
                depth = rescale_depth(depths[i], (h_scaled, w_scaled))
                depth = F.center_crop(depth, (h_out, w_out))
                depths_new.append(depth)
        return (
            torch.stack(images_new),
            torch.stack(intrinsics_new),
            torch.stack(depths_new) if depths is not None else None,
        )
    else:
        # intr_aug is only supported for clean (uniform-resolution) datasets.
        *_, h_in, w_in = images.shape
        h_out, w_out = shape
        # assert h_out <= h_in and w_out <= w_in  # to avoid the case that the image is too small, like co3d
        if intr_aug:
            # Crop a randomly shrunken window; it is resized back up to
            # `shape` below, which enlarges the effective focal length.
            scale = random.uniform(*scale_range)
            h_scale = round(h_out * scale)
            w_scale = round(w_out * scale)
        else:
            h_scale = h_out
            w_scale = w_out
        scale_factor = max(h_scale / h_in, w_scale / w_in)
        h_scaled = round(h_in * scale_factor)
        w_scaled = round(w_in * scale_factor)
        assert h_scaled == h_scale or w_scaled == w_scale
        # Reshape the images to the correct size. Assume we don't have to worry
        # about changing the intrinsics based on how the images are rounded.
        *batch, c, h, w = images.shape
        images = images.reshape(-1, c, h, w)
        images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images])
        images = images.reshape(*batch, c, h_scaled, w_scaled)
        if depths is not None:
            if isinstance(depths, list):
                depths = torch.stack(
                    [rescale_depth(depth, (h_scaled, w_scaled)) for depth in depths]
                )
            else:
                depths = depths.reshape(-1, h, w)
                depths = torch.stack(
                    [rescale_depth(depth, (h_scaled, w_scaled)) for depth in depths]
                )
                depths = depths.reshape(*batch, h_scaled, w_scaled)
        images, intrinsics, depths = center_crop(images, intrinsics, (h_scale, w_scale), depths)
        if intr_aug:
            # Resize the augmented crop back to the target shape; normalized
            # intrinsics are unchanged by this final resize.
            images = F.resize(images, size=(h_out, w_out), interpolation=F.InterpolationMode.BILINEAR)
            if depths is not None:
                depths = F.resize(depths, size=(h_out, w_out), interpolation=F.InterpolationMode.NEAREST)
        return images, intrinsics, depths
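
# Sketch of the cover-then-crop arithmetic (hypothetical sizes): a 300x400
# image headed for a 256x256 crop is first resized so both sides cover 256.
#
#   scale_factor = max(256 / 300, 256 / 400)  # = 0.8533...
#   (round(300 * scale_factor), round(400 * scale_factor))  # -> (256, 341)
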
def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int], intr_aug: bool = False) -> AnyViews:
    if "depth" in views:
        images, intrinsics, depths = rescale_and_crop(
            views["image"], views["intrinsics"], shape, depths=views["depth"], intr_aug=intr_aug
        )
        return {
            **views,
            "image": images,
            "intrinsics": intrinsics,
            "depth": depths,
        }
    else:
        images, intrinsics, _ = rescale_and_crop(views["image"], views["intrinsics"], shape, intr_aug=intr_aug)
        return {
            **views,
            "image": images,
            "intrinsics": intrinsics,
        }

def apply_crop_shim(example: AnyExample, shape: tuple[int, int], intr_aug: bool = False) -> AnyExample:
"""Crop images in the example."""
return {
**example,
"context": apply_crop_shim_to_views(example["context"], shape, intr_aug),
"target": apply_crop_shim_to_views(example["target"], shape, intr_aug),
}
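
if __name__ == "__main__":
    # Minimal smoke test on dummy data (hypothetical shapes; run with
    # `python -m ...` so the relative import above resolves).
    views = {
        "image": torch.rand(2, 3, 360, 640),
        "intrinsics": torch.eye(3).expand(2, 3, 3).clone(),
        "depth": torch.rand(2, 1, 360, 640),
    }
    out = apply_crop_shim_to_views(views, (256, 256))
    print(out["image"].shape)       # torch.Size([2, 3, 256, 256])
    print(out["depth"].shape)       # torch.Size([2, 256, 256])
    print(out["intrinsics"].shape)  # torch.Size([2, 3, 3])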