|
import copy |
|
import random |
|
import numpy as np |
|
import torch |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from ..types import AnyExample, AnyViews |
|
|
|
|
|
def reflect_extrinsics(
    extrinsics: Float[Tensor, "*batch 4 4"],
) -> Float[Tensor, "*batch 4 4"]:
    """Mirror 4x4 camera extrinsics across the x axis.

    Computes ``F @ extrinsics @ F`` with ``F = diag(-1, 1, 1, 1)``. Because
    ``F`` is its own inverse, this is a change of basis that negates the x
    axis on both sides of the transform. It is paired with a horizontal image
    flip in ``reflect_views``.

    Args:
        extrinsics: Tensor of shape ``(*batch, 4, 4)``.

    Returns:
        Tensor with the same shape, dtype and device as ``extrinsics``.
    """
    # Build F in the input's dtype: matmul does not promote dtypes, so a fixed
    # float32 matrix fails on float64 / float16 inputs.
    reflect = torch.eye(4, dtype=extrinsics.dtype, device=extrinsics.device)
    reflect[0, 0] = -1
    return reflect @ extrinsics @ reflect
|
|
|
|
|
def reflect_views(views: AnyViews) -> AnyViews:
    """Horizontally mirror a set of views.

    Returns a new dict; the input mapping is not modified. The new dict has:
      * ``"image"`` flipped along its last axis;
      * ``"extrinsics"`` mirrored with ``reflect_extrinsics``;
      * ``"depth"`` (only if present) flipped along its last axis.
    Every other entry is carried over unchanged.

    NOTE(review): any intrinsics entry is passed through untouched. A
    horizontal flip maps a normalized principal point cx to 1 - cx, so this is
    only exact when cx is centered. Confirm this holds for the datasets used.
    """
    reflected = {
        **views,
        "image": views["image"].flip(-1),
        "extrinsics": reflect_extrinsics(views["extrinsics"]),
    }
    # "depth" is optional. Because it already came in via **views, assigning
    # it here keeps its original key position.
    if "depth" in views:
        reflected["depth"] = views["depth"].flip(-1)
    return reflected
|
|
|
|
|
def apply_augmentation_shim(
    example: AnyExample,
    generator: torch.Generator | None = None,
) -> AnyExample:
    """Randomly augment the training images.

    Draws one uniform sample with ``torch.rand`` (using ``generator`` when
    given). If the sample is below 0.5, ``example`` is returned as-is.
    Otherwise a new dict is returned in which both the ``"context"`` and
    ``"target"`` views are horizontally mirrored via ``reflect_views``.
    """
    coin = torch.rand(tuple(), generator=generator)
    keep_unchanged = coin < 0.5
    if not keep_unchanged:
        mirrored_context = reflect_views(example["context"])
        mirrored_target = reflect_views(example["target"])
        return {
            **example,
            "context": mirrored_context,
            "target": mirrored_target,
        }
    return example
|
|
|
def rotate_90_degrees(
    image: torch.Tensor, depth_map: torch.Tensor | None, extri_opencv: torch.Tensor, intri_opencv: torch.Tensor, clockwise=True
):
    """
    Rotate an image, its optional depth map, and its camera parameters by 90 degrees.

    Applies one of two 90-degree in-plane rotations:
      - clockwise (default);
      - counterclockwise (if ``clockwise=False``).

    The extrinsic and intrinsic matrices are adjusted so that projecting with
    the new cameras agrees with the rotated pixels.

    Args:
        image (torch.Tensor):
            Image tensor whose last two dims are (H, W), e.g. (C, H, W).
        depth_map (torch.Tensor or None):
            Depth map whose last two dims are (H, W), or None if not available.
        extri_opencv (torch.Tensor):
            Extrinsic matrix (3x4 or 4x4) in OpenCV convention; only the top
            3 rows are read. NOTE(review): the adjustment assumes it maps
            world -> camera coordinates; confirm with callers.
        intri_opencv (torch.Tensor):
            Normalized 3x3 intrinsic matrix: row 0 is divided by the image
            width and row 1 by the image height.
        clockwise (bool):
            If True, rotate 90 degrees clockwise; otherwise counterclockwise.

    Returns:
        tuple: (rotated_image, rotated_depth_map, new_extri_opencv, new_intri_opencv)
            - rotated_image / rotated_depth_map: the last two dims become (W, H);
              rotated_depth_map is None if depth_map was None.
            - new_extri_opencv: a 4x4 extrinsic matrix (bottom row [0, 0, 0, 1]).
            - new_intri_opencv: a 3x3 intrinsic matrix normalized by the
              rotated image size (width H, height W).
    """
    image_height, image_width = image.shape[-2:]

    rotated_image, rotated_depth_map = rotate_image_and_depth_rot90(image, depth_map, clockwise)

    # The intrinsics need the ORIGINAL size: they are de-normalized by it, and
    # the rotated size is simply its transpose.
    new_intri_opencv = adjust_intrinsic_matrix_rot90(intri_opencv, image_width, image_height, clockwise)

    new_extri_opencv = adjust_extrinsic_matrix_rot90(extri_opencv, clockwise)

    return (
        rotated_image,
        rotated_depth_map,
        new_extri_opencv,
        new_intri_opencv,
    )
|
|
|
|
|
def rotate_image_and_depth_rot90(image: torch.Tensor, depth_map: torch.Tensor | None, clockwise: bool): |
|
""" |
|
Rotates the given image and depth map by 90 degrees (clockwise or counterclockwise). |
|
|
|
Args: |
|
image (torch.Tensor): |
|
Input image tensor of shape (C, H, W). |
|
depth_map (torch.Tensor or None): |
|
Depth map tensor of shape (H, W), or None if not available. |
|
clockwise (bool): |
|
If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise. |
|
|
|
Returns: |
|
tuple: |
|
(rotated_image, rotated_depth_map) |
|
""" |
|
rotated_depth_map = None |
|
if clockwise: |
|
rotated_image = torch.rot90(image, k=-1, dims=[-2, -1]) |
|
if depth_map is not None: |
|
rotated_depth_map = torch.rot90(depth_map, k=-1, dims=[-2, -1]) |
|
else: |
|
rotated_image = torch.rot90(image, k=1, dims=[-2, -1]) |
|
if depth_map is not None: |
|
rotated_depth_map = torch.rot90(depth_map, k=1, dims=[-2, -1]) |
|
return rotated_image, rotated_depth_map |
|
|
|
|
|
def adjust_extrinsic_matrix_rot90(extri_opencv: torch.Tensor, clockwise: bool):
    """
    Adjust an extrinsic matrix for a 90-degree in-plane rotation of the image.

    The camera frame is rotated about its optical (z) axis:
      - clockwise image rotation:        (X, Y, Z) -> (-Y,  X, Z)
      - counterclockwise image rotation: (X, Y, Z) -> ( Y, -X, Z)
    Equivalently, ``[R | t]`` becomes ``R_rot @ [R | t]``.

    NOTE(review): left-multiplication is correct when the matrix maps world ->
    camera coordinates (OpenCV w2c). If callers pass camera-to-world matrices,
    the rotation would have to be applied on the right instead; confirm.

    Args:
        extri_opencv (torch.Tensor):
            Extrinsic matrix of shape (..., 3, 4) or (..., 4, 4); only the top
            3 rows (``[R | t]``) are read.
        clockwise (bool):
            True for a clockwise image rotation, False for counterclockwise.

    Returns:
        torch.Tensor:
            A new (..., 4, 4) extrinsic matrix with bottom row [0, 0, 0, 1],
            in the input's dtype and device. The input is not modified.
    """
    if clockwise:
        R_rotation = torch.tensor([
            [0, -1, 0],
            [1, 0, 0],
            [0, 0, 1]
        ], dtype=extri_opencv.dtype, device=extri_opencv.device)
    else:
        R_rotation = torch.tensor([
            [0, 1, 0],
            [-1, 0, 0],
            [0, 0, 1]
        ], dtype=extri_opencv.dtype, device=extri_opencv.device)

    # R_rot @ [R | t] == [R_rot @ R | R_rot @ t]; matmul broadcasts over batch dims.
    top = torch.matmul(R_rotation, extri_opencv[..., :3, :4])

    # Append the homogeneous row [0, 0, 0, 1] to get a 4x4 matrix.
    bottom = torch.zeros(*top.shape[:-2], 1, 4, dtype=top.dtype, device=top.device)
    bottom[..., 0, 3] = 1
    return torch.cat((top, bottom), dim=-2)
|
|
|
|
|
def adjust_intrinsic_matrix_rot90(intri_opencv: torch.Tensor, image_width: int, image_height: int, clockwise: bool):
    """
    Adjust a normalized 3x3 intrinsic matrix for a 90-degree in-plane image rotation.

    Input and output intrinsics are normalized. On input, row 0 is divided by
    the original width W and row 1 by the original height H. On output, they
    are normalized by the rotated image size: width H, height W.

    In pixel units:
      - clockwise:        (u, v) -> (H - v, u); fx' = fy, fy' = fx, cx' = H - cy, cy' = cx
      - counterclockwise: (u, v) -> (v, W - u); fx' = fy, fy' = fx, cx' = cy, cy' = W - cx
    Any skew term (entry [0, 1]) is dropped.

    Args:
        intri_opencv (torch.Tensor):
            Normalized intrinsic matrix (3x3). It is not modified.
        image_width (int):
            Original width W of the image.
        image_height (int):
            Original height H of the image.
        clockwise (bool):
            If True, rotate 90 degrees clockwise; else 90 degrees counterclockwise.

    Returns:
        torch.Tensor:
            A new normalized 3x3 intrinsic matrix after the rotation.
    """
    # clone() rather than copy.deepcopy: deepcopy raises on non-leaf tensors in
    # an autograd graph, while clone() works and keeps gradients flowing.
    intri_px = intri_opencv.clone()
    # De-normalize to pixel units of the original image.
    intri_px[0, :] *= image_width
    intri_px[1, :] *= image_height

    fx, fy, cx, cy = (
        intri_px[0, 0],
        intri_px[1, 1],
        intri_px[0, 2],
        intri_px[1, 2],
    )

    new_intri_opencv = torch.eye(3, dtype=intri_px.dtype, device=intri_px.device)
    # Focal lengths swap axes in both directions.
    new_intri_opencv[0, 0] = fy
    new_intri_opencv[1, 1] = fx
    if clockwise:
        new_intri_opencv[0, 2] = image_height - cy
        new_intri_opencv[1, 2] = cx
    else:
        new_intri_opencv[0, 2] = cy
        new_intri_opencv[1, 2] = image_width - cx

    # Re-normalize by the rotated size: the new width is H and the new height is W.
    new_intri_opencv[0, :] /= image_height
    new_intri_opencv[1, :] /= image_width

    return new_intri_opencv
|
|