import cv2
import ipdb  # noqa: F401
import numpy as np
import torch
from pytorch3d.renderer import PerspectiveCameras, RayBundle
from pytorch3d.transforms import Rotate, Translate

from diffusionsfm.utils.distortion import apply_distortion_tensor
from diffusionsfm.utils.normalize import (
    compute_optical_axis_intersection,
    intersect_skew_line_groups,
    intersect_skew_lines_high_dim,
)

class Rays(object):
    def __init__(
        self,
        rays=None,
        origins=None,
        directions=None,
        moments=None,
        segments=None,
        depths=None,
        moments_rescale=1.0,
        ndc_coordinates=None,
        crop_parameters=None,
        num_patches_x=16,
        num_patches_y=16,
        distortion_coeffs=None,
        camera_coordinate_rays=None,
        mode=None,
        unprojected=None,
        depth_resolution=1,
        row_form=False,
    ):
        """
        Ray class to keep track of the current ray representation.

        Args:
            rays: (..., 6).
            origins: (..., 3).
            directions: (..., 3).
            moments: (..., 3).
            mode: One of "ray", "plucker", or "segment".
            moments_rescale: Rescale the moment component of the rays by a scalar.
            ndc_coordinates: (..., 2): NDC coordinates of each ray.
        """
        self.depth_resolution = depth_resolution
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y

        if rays is not None:
            self.rays = rays
            assert mode is not None
            self._mode = mode
        elif segments is not None:
            if not row_form:
                segments = Rays.patches_to_rows(
                    segments,
                    num_patches_x=num_patches_x,
                    num_patches_y=num_patches_y,
                    depth_resolution=depth_resolution,
                )
            self.rays = torch.cat((origins, segments), dim=-1)
            self._mode = "segment"
        elif origins is not None and directions is not None:
            self.rays = torch.cat((origins, directions), dim=-1)
            self._mode = "ray"
        elif directions is not None and moments is not None:
            self.rays = torch.cat((directions, moments), dim=-1)
            self._mode = "plucker"
        else:
            raise Exception("Invalid combination of arguments")

        if depths is not None:
            assert mode is not None
            self._mode = mode
            self.depths = depths
        else:
            self.depths = None

        self.unprojected = unprojected

        if moments_rescale != 1.0:
            self.rescale_moments(moments_rescale)

        if ndc_coordinates is not None:
            self.ndc_coordinates = ndc_coordinates
        elif crop_parameters is not None:
            # (..., H, W, 2)
            xy_grid = compute_ndc_coordinates(
                crop_parameters,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                distortion_coeffs=distortion_coeffs,
            )[..., :2]
            xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
            self.ndc_coordinates = xy_grid
        else:
            self.ndc_coordinates = None

        if camera_coordinate_rays is not None:
            self.camera_ray_coordinates = True
            self.camera_coordinate_ray_directions = camera_coordinate_rays
        else:
            self.camera_ray_coordinates = False

    def __getitem__(self, index):
        cam_coord_rays = None
        if self.camera_ray_coordinates:
            cam_coord_rays = self.camera_coordinate_ray_directions[index]

        return Rays(
            rays=self.rays[index],
            mode=self._mode,
            camera_coordinate_rays=cam_coord_rays,
            ndc_coordinates=(
                self.ndc_coordinates[index]
                if self.ndc_coordinates is not None
                else None
            ),
            num_patches_x=self.num_patches_x,
            num_patches_y=self.num_patches_y,
            depths=self.depths[index] if self.depths is not None else None,
            unprojected=(
                self.unprojected[index] if self.unprojected is not None else None
            ),
            depth_resolution=self.depth_resolution,
        )

    def __len__(self):
        return self.rays.shape[0]

    def to_spatial(
        self, include_ndc_coordinates=False, include_depths=False, use_homogeneous=False
    ):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)

        If use_homogeneous is True, then each 3D component is lifted to 4D and
        normalized.

        Returns:
            torch.Tensor: (..., 6, H, W)
        """
        if self._mode == "ray":
            rays = self.to_plucker().rays
        else:
            rays = self.rays

        *batch_dims, P, D = rays.shape
        H = W = int(np.sqrt(P))
        assert H * W == P

        if use_homogeneous:
            rays_reshaped = rays.reshape(*batch_dims, P, D // 3, 3)
            ones = torch.ones_like(rays_reshaped[..., :1])
            rays_reshaped = torch.cat((rays_reshaped, ones), dim=-1)
            rays = torch.nn.functional.normalize(rays_reshaped, dim=-1)
            D = (4 * D) // 3
            rays = rays.reshape(*batch_dims, P, D)

        rays = torch.transpose(rays, -1, -2)  # (..., 6, H * W)
        rays = rays.reshape(*batch_dims, D, H, W)

        if include_depths:
            depths = self.depths.unsqueeze(1)
            rays = torch.cat((rays, depths), dim=-3)

        if include_ndc_coordinates:
            ndc_coords = self.ndc_coordinates.transpose(-1, -2)  # (..., 2, H * W)
            ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
            rays = torch.cat((rays, ndc_coords), dim=-3)

        return rays

    def to_spatial_with_camera_coordinate_rays(
        self,
        I_camera,
        crop_params,
        moments_rescale=1.0,
        include_ndc_coordinates=False,
        use_homogeneous=False,
    ):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)

        Returns:
            torch.Tensor: (..., 6, H, W)
        """
        rays = self.to_spatial(
            include_ndc_coordinates=include_ndc_coordinates,
            use_homogeneous=use_homogeneous,
        )
        N, _, H, W = rays.shape

        camera_coord_rays = (
            cameras_to_rays(
                cameras=I_camera,
                num_patches_x=H,
                num_patches_y=W,
                crop_parameters=crop_params,
            )
            .rescale_moments(1 / moments_rescale)
            .get_directions()
        )
        self.camera_coordinate_ray_directions = camera_coord_rays

        camera_coord_rays = torch.transpose(camera_coord_rays, -1, -2)
        camera_coord_rays = camera_coord_rays.reshape(N, 3, H, W)
        rays = torch.cat((camera_coord_rays, rays), dim=-3)
        return rays

    def rescale_moments(self, scale):
        """
        Rescale the moment component of the rays by a scalar. Might be desirable
        since moments may come from a very narrow distribution.

        Note that this modifies in place!
        """
        assert False, "Deprecated"
        if self._mode == "plucker":
            self.rays[..., 3:] *= scale
            return self
        else:
            return self.to_plucker().rescale_moments(scale)

    def to_spatial_with_camera_coordinate_rays_object(
        self,
        I_camera,
        crop_params,
        moments_rescale=1.0,
        include_ndc_coordinates=False,
        use_homogeneous=False,
    ):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)

        Returns:
            torch.Tensor: (..., 6, H, W)
        """
        rays = self.to_spatial(
            include_ndc_coordinates=include_ndc_coordinates,
            use_homogeneous=use_homogeneous,
        )
        N, _, H, W = rays.shape

        camera_coord_rays = (
            cameras_to_rays(
                cameras=I_camera,
                num_patches_x=H,
                num_patches_y=W,
                crop_parameters=crop_params,
            )
            .rescale_moments(1 / moments_rescale)
            .get_directions()
        )
        self.camera_coordinate_ray_directions = camera_coord_rays

        camera_coord_rays = torch.transpose(camera_coord_rays, -1, -2)
        camera_coord_rays = camera_coord_rays.reshape(N, 3, H, W)
        rays = torch.cat((camera_coord_rays, rays), dim=-3)
        return rays

    @classmethod
    def patches_to_rows(cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1):
        B, P, C = x.shape
        assert P == (depth_resolution**2 * num_patches_x * num_patches_y)

        x = x.reshape(
            B,
            depth_resolution * num_patches_x,
            depth_resolution * num_patches_y,
            C,
        )
        new = x.unfold(1, depth_resolution, depth_resolution).unfold(
            2, depth_resolution, depth_resolution
        )
        new = new.permute((0, 1, 2, 4, 5, 3))
        new = new.reshape(
            (B, num_patches_x * num_patches_y, depth_resolution * depth_resolution * C)
        )
        return new

    @classmethod
    def rows_to_patches(cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1):
        B, P, CP = x.shape
        assert P == (num_patches_x * num_patches_y)
        C = CP // (depth_resolution**2)
        HP, WP = num_patches_x * depth_resolution, num_patches_y * depth_resolution

        x = x.reshape(
            B, num_patches_x, num_patches_y, depth_resolution, depth_resolution, C
        )
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, HP * WP, C)
        return x

    @classmethod
    def upsample_origins(
        cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1
    ):
        B, P, C = x.shape
        origins = x.permute((0, 2, 1))
        origins = origins.reshape((B, C, num_patches_x, num_patches_y))
        origins = torch.nn.functional.interpolate(
            origins, scale_factor=(depth_resolution, depth_resolution)
        )
        origins = origins.permute((0, 2, 3, 1)).reshape(
            (B, P * depth_resolution * depth_resolution, C)
        )
        return origins

    @classmethod
    def from_spatial_with_camera_coordinate_rays(
        cls, rays, mode, moments_rescale=1.0, use_homogeneous=False
    ):
        """
        Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)

        Args:
            rays: (..., 6, H, W)

        Returns:
            Rays: (..., H * W, 6)
        """
        *batch_dims, D, H, W = rays.shape
        rays = rays.reshape(*batch_dims, D, H * W)
        rays = torch.transpose(rays, -1, -2)
        camera_coordinate_ray_directions = rays[..., :3]
        rays = rays[..., 3:]
        return cls(
            rays=rays,
            mode=mode,
            moments_rescale=moments_rescale,
            camera_coordinate_rays=camera_coordinate_ray_directions,
        )

    @classmethod
    def from_spatial(
        cls,
        rays,
        mode,
        moments_rescale=1.0,
        ndc_coordinates=None,
        num_patches_x=16,
        num_patches_y=16,
        use_homogeneous=False,
    ):
        """
        Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)

        Args:
            rays: (..., 6, H, W)

        Returns:
            Rays: (..., H * W, 6)
        """
        *batch_dims, D, H, W = rays.shape
        rays = rays.reshape(*batch_dims, D, H * W)
        rays = torch.transpose(rays, -1, -2)

        if use_homogeneous:
            D -= 2

        if D == 7:
            if use_homogeneous:
                depths = rays[..., 8]
                r1 = rays[..., :3] / (rays[..., 3:4] + 1e-6)
                r2 = rays[..., 4:7] / (rays[..., 7:8] + 1e-6)
                rays = torch.cat((r1, r2), dim=-1)
            else:
                old_rays = rays
                rays = rays[..., :6]
                depths = old_rays[..., 6]
            return cls(
                rays=rays,
                mode=mode,
                moments_rescale=moments_rescale,
                ndc_coordinates=ndc_coordinates,
                depths=depths.reshape(*batch_dims, H, W),
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
            )
        elif D > 7:
            D += 2
            if use_homogeneous:
                rays_reshaped = rays.reshape((*batch_dims, H * W, D // 4, 4))
                rays_not_homo = rays_reshaped / rays_reshaped[..., :, 3].unsqueeze(-1)
                rays = rays_not_homo[..., :, :3].reshape(
                    (*batch_dims, H * W, (D // 4) * 3)
                )
                D = (D // 4) * 3

            ray = cls(
                origins=rays[:, :, :3],
                segments=rays[:, :, 3:],
                mode="segment",
                moments_rescale=moments_rescale,
                ndc_coordinates=ndc_coordinates,
                row_form=True,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                depth_resolution=int(((D - 3) // 3) ** 0.5),
            )
            if mode == "ray":
                return ray.to_point_direction()
            elif mode == "plucker":
                return ray.to_plucker()
            elif mode == "segment":
                return ray
            else:
                assert False
        else:
            if use_homogeneous:
                r1 = rays[..., :3] / (rays[..., 3:4] + 1e-6)
                r2 = rays[..., 4:7] / (rays[..., 7:8] + 1e-6)
                rays = torch.cat((r1, r2), dim=-1)
            return cls(
                rays=rays,
                mode=mode,
                moments_rescale=moments_rescale,
                ndc_coordinates=ndc_coordinates,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
            )

    def to_point_direction(self, normalize_moment=True):
        """
        Convert to point-direction representation <O, D>.

        Returns:
            rays: (..., 6).
        """
        if self._mode == "plucker":
            direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
            moment = self.rays[..., 3:]
            if normalize_moment:
                # Divide the moment by the norm of the (unnormalized) direction.
                c = torch.linalg.norm(self.rays[..., :3], dim=-1, keepdim=True)
                moment = moment / c
            points = torch.cross(direction, moment, dim=-1)
            return Rays(
                rays=torch.cat((points, direction), dim=-1),
                mode="ray",
                ndc_coordinates=self.ndc_coordinates,
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depths=self.depths,
                unprojected=self.unprojected,
                depth_resolution=self.depth_resolution,
            )
        elif self._mode == "segment":
            origins = self.get_origins(high_res=True)
            direction = self.get_segments() - origins
            direction = torch.nn.functional.normalize(direction, dim=-1)
            return Rays(
                rays=torch.cat((origins, direction), dim=-1),
                mode="ray",
                ndc_coordinates=self.ndc_coordinates,
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depths=self.depths,
                unprojected=self.unprojected,
                depth_resolution=self.depth_resolution,
            )
        else:
            return self

    def to_plucker(self):
        """
        Convert to Plucker representation <D, O x D>.
        """
        if self._mode == "plucker":
            return self
        elif self._mode == "ray":
            ray = self.rays.clone()
            ray_origins = ray[..., :3]
            ray_directions = ray[..., 3:]

            # Normalize ray directions to unit vectors
            ray_directions = ray_directions / torch.linalg.vector_norm(
                ray_directions, dim=-1, keepdim=True
            )
            plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
            new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
            return Rays(
                rays=new_ray,
                mode="plucker",
                ndc_coordinates=self.ndc_coordinates,
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depths=self.depths,
                unprojected=self.unprojected,
                depth_resolution=self.depth_resolution,
            )
        elif self._mode == "segment":
            return self.to_point_direction().to_plucker()

    def get_directions(self, normalize=True):
        if self._mode == "plucker":
            directions = self.rays[..., :3]
        elif self._mode == "segment":
            directions = self.to_point_direction().get_directions()
        else:
            directions = self.rays[..., 3:]
        if normalize:
            directions = torch.nn.functional.normalize(directions, dim=-1)
        return directions

    def get_camera_coordinate_rays(self, normalize=True):
        directions = self.camera_coordinate_ray_directions
        if normalize:
            directions = torch.nn.functional.normalize(directions, dim=-1)
        return directions

    def get_origins(self, high_res=False):
        if self._mode == "plucker":
            origins = self.to_point_direction().get_origins(high_res=high_res)
        elif self._mode == "ray":
            origins = self.rays[..., :3]
        elif self._mode == "segment":
            origins = Rays.upsample_origins(
                self.rays[..., :3],
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depth_resolution=self.depth_resolution,
            )
        else:
            assert False
        return origins

    def get_moments(self):
        if self._mode == "plucker":
            moments = self.rays[..., 3:]
        elif self._mode in ["ray", "segment"]:
            moments = self.to_plucker().get_moments()
        return moments

    def get_segments(self):
        assert self._mode == "segment"
        if self.unprojected is not None:
            return self.unprojected
        else:
            return Rays.rows_to_patches(
                self.rays[..., 3:],
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depth_resolution=self.depth_resolution,
            )

    def get_ndc_coordinates(self):
        return self.ndc_coordinates

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode

    @property
    def device(self):
        return self.rays.device

    def __repr__(self, *args, **kwargs):
        ray_str = self.rays.__repr__(*args, **kwargs)[6:]  # remove "tensor"
        if self._mode == "plucker":
            return "PluRay" + ray_str
        elif self._mode == "ray":
            return "DirRay" + ray_str
        else:
            return "SegRay" + ray_str

    def to(self, device):
        self.rays = self.rays.to(device)

    def clone(self):
        return Rays(rays=self.rays.clone(), mode=self._mode)

    @property
    def shape(self):
        return self.rays.shape

    def visualize(self):
        directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
        moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
        return (directions + 1) / 2, (moments + 1) / 2

    def to_ray_bundle(self, length=0.3, recompute_origin=False, true_length=False):
        """
        Args:
            length (float): Length of the rays for visualization.
            recompute_origin (bool): If True, origin is set to the intersection point
                of all rays. If False, origins are the points along each ray closest
                to the common intersection point.
        """
        origins = self.get_origins(high_res=self.depth_resolution > 1)
        lengths = torch.ones_like(origins[..., :2]) * length
        lengths[..., 0] = 0

        p_intersect, p_closest, _, _ = intersect_skew_line_groups(
            origins.float(), self.get_directions().float()
        )
        if recompute_origin:
            centers = p_intersect
            centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
        else:
            centers = p_closest

        if true_length:
            length = torch.norm(self.get_segments() - centers, dim=-1).unsqueeze(-1)
            lengths = torch.ones_like(origins[..., :2]) * length
            lengths[..., 0] = 0

        return RayBundle(
            origins=centers,
            directions=self.get_directions(),
            lengths=lengths,
            xys=self.get_directions(),
        )


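# Usage sketch (illustrative shapes, not part of the original module): converting a
# batch of rays between the representations handled by the class above.
#
#   origins = torch.zeros(2, 256, 3)         # 2 cameras, 16x16 patches
#   directions = torch.randn(2, 256, 3)
#   rays = Rays(origins=origins, directions=directions, mode="ray")
#   plucker = rays.to_plucker()              # <D, O x D> encoding
#   spatial = plucker.to_spatial()           # (2, 6, 16, 16) feature map
#   recovered = Rays.from_spatial(spatial, mode="plucker")

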
def cameras_to_rays(
    cameras,
    crop_parameters,
    use_half_pix=True,
    use_plucker=True,
    num_patches_x=16,
    num_patches_y=16,
    no_crop_param_device="cpu",
    distortion_coeffs=None,
    depths=None,
    visualize=False,
    mode=None,
    depth_resolution=1,
    nearest_neighbor=True,
    distortion_coefficients=None,
):
    """
    Unprojects rays from the camera center to a grid on the image plane.

    To match Moneish's code, set use_half_pix=False, use_plucker=True. Also, the
    arguments to meshgrid should be swapped (x first, then y). I'm following the
    Pytorch3D convention of having y first.

    distortion_coeffs refers to Amy's distortion experiments;
    distortion_coefficients refers to the fisheye parameters from COLMAP.

    Args:
        cameras: Pytorch3D cameras to unproject. Can be batched.
        crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
            Shape is (B, 4).
        use_half_pix: If True, use half-pixel offset (Default: True).
        use_plucker: If True, return rays in Plucker coordinates (Default: True).
        num_patches_x: Number of patches in x direction (Default: 16).
        num_patches_y: Number of patches in y direction (Default: 16).
    """
    unprojected = []
    unprojected_ones = []
    crop_parameters_list = (
        crop_parameters if crop_parameters is not None else [None for _ in cameras]
    )
    depths_list = depths if depths is not None else [None for _ in cameras]

    if distortion_coeffs is None:
        zs = []
        for i, (camera, crop_param, depth) in enumerate(
            zip(cameras, crop_parameters_list, depths_list)
        ):
            xyd_grid = compute_ndc_coordinates(
                crop_parameters=crop_param,
                use_half_pix=use_half_pix,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                no_crop_param_device=no_crop_param_device,
                depths=depth,
                return_zs=True,
                depth_resolution=depth_resolution,
                nearest_neighbor=nearest_neighbor,
            )
            xyd_grid, z, ones_grid = xyd_grid
            zs.append(z)

            if (
                distortion_coefficients is not None
                and (distortion_coefficients[i] != 0).any()
            ):
                xyd_grid = undistort_ndc_coordinates(
                    ndc_coordinates=xyd_grid,
                    principal_point=camera.principal_point[0],
                    focal_length=camera.focal_length[0],
                    distortion_coefficients=distortion_coefficients[i],
                )

            unprojected.append(
                camera.unproject_points(
                    xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
                )
            )
            if depths is not None and mode == "plucker":
                unprojected_ones.append(
                    camera.unproject_points(
                        ones_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
                    )
                )
    else:
        for camera, crop_param, distort_coeff in zip(
            cameras, crop_parameters_list, distortion_coeffs
        ):
            xyd_grid = compute_ndc_coordinates(
                crop_parameters=crop_param,
                use_half_pix=use_half_pix,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                no_crop_param_device=no_crop_param_device,
                distortion_coeffs=distort_coeff,
                depths=depths,
                nearest_neighbor=nearest_neighbor,
            )
            unprojected.append(
                camera.unproject_points(
                    xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
                )
            )

    unprojected = torch.stack(unprojected, dim=0)  # (N, P, 3)
    origins = cameras.get_camera_center().unsqueeze(1)  # (N, 1, 3)
    origins = origins.repeat(1, num_patches_x * num_patches_y, 1)  # (N, P, 3)

    if depths is None:
        directions = unprojected - origins
        rays = Rays(
            origins=origins,
            directions=directions,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
            mode="ray",
            unprojected=unprojected,
        )
        if use_plucker:
            return rays.to_plucker()
    elif mode == "segment":
        rays = Rays(
            origins=origins,
            segments=unprojected,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
            depths=torch.stack(zs, dim=0),
            mode=mode,
            unprojected=unprojected,
            depth_resolution=depth_resolution,
        )
    elif mode == "plucker" or mode == "ray":
        unprojected_ones = torch.stack(unprojected_ones)
        directions = unprojected_ones - origins
        rays = Rays(
            origins=origins,
            directions=directions,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
            depths=torch.stack(zs, dim=0),
            mode="ray",
            unprojected=unprojected,
        )
        if mode == "plucker":
            rays = rays.to_plucker()
    else:
        assert False

    if visualize:
        return rays, unprojected, torch.stack(zs, dim=0)
    return rays


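# Usage sketch (illustrative, not part of the original module): unprojecting a batch
# of cameras into per-patch Plucker rays over a full-image crop, using the
# (cc_x, cc_y, crop_width, scale) convention documented above.
#
#   cameras = PerspectiveCameras(focal_length=[3.453] * 2)
#   crop_parameters = torch.tensor([[0.0, 0.0, 2.0, 1.0]] * 2)
#   rays = cameras_to_rays(cameras, crop_parameters)   # Rays in "plucker" mode
#   directions = rays.get_directions()                 # (2, 256, 3)

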
def rays_to_cameras(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    no_crop_param_device="cpu",
    sampled_ray_idx=None,
    cameras=None,
    focal_length=(3.453,),
    distortion_coeffs=None,
    calculate_distortion=False,
    depth_resolution=1,
    average_centers=False,
):
    """
    If cameras are provided, will use those intrinsics. Otherwise will use the
    provided focal_length(s). Dataset default is 3.32.

    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
    """
    device = rays.device
    origins = rays.get_origins(high_res=True)
    directions = rays.get_directions()

    if average_centers:
        camera_centers = torch.mean(origins, dim=1)
    else:
        camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)

    # Retrieve target rays
    if cameras is None:
        if len(focal_length) == 1:
            focal_length = focal_length * rays.shape[0]
        I_camera = PerspectiveCameras(focal_length=focal_length, device=device)
    else:
        # Use same intrinsics but reset to identity extrinsics.
        I_camera = cameras.clone()
        I_camera.R[:] = torch.eye(3, device=device)
        I_camera.T[:] = torch.zeros(3, device=device)

    if distortion_coeffs is not None and not calculate_distortion:
        coeff = distortion_coeffs
    else:
        coeff = None

    I_patch_rays = cameras_to_rays(
        cameras=I_camera,
        num_patches_x=num_patches_x * depth_resolution,
        num_patches_y=num_patches_y * depth_resolution,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
        no_crop_param_device=no_crop_param_device,
        distortion_coeffs=coeff,
        mode="plucker",
        depth_resolution=depth_resolution,
    ).get_directions()

    if sampled_ray_idx is not None:
        I_patch_rays = I_patch_rays[:, sampled_ray_idx]

    # Compute optimal rotation to align rays
    R = torch.zeros_like(I_camera.R)
    for i in range(len(I_camera)):
        R[i] = compute_optimal_rotation_alignment(
            I_patch_rays[i],
            directions[i],
        )

    # Construct and return rotated camera
    cam = I_camera.clone()
    cam.R = R
    cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
    return cam


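# Round-trip sketch (illustrative; R_gt and T_gt are hypothetical ground-truth
# extrinsics): cameras -> rays -> cameras should approximately recover the input
# for noiseless rays.
#
#   cameras = PerspectiveCameras(focal_length=[3.453], R=R_gt, T=T_gt)
#   crop_parameters = torch.tensor([[0.0, 0.0, 2.0, 1.0]])
#   rays = cameras_to_rays(cameras, crop_parameters)
#   cameras_hat = rays_to_cameras(rays, crop_parameters)
#   # cameras_hat.R / cameras_hat.T should match R_gt / T_gt up to numerics.

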
# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
def ql_decomposition(A):
    P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
    A_tilde = torch.matmul(A, P)
    Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
    Q = torch.matmul(Q_tilde, P)
    L = torch.matmul(torch.matmul(P, R_tilde), P)
    d = torch.diag(L)
    Q[:, 0] *= torch.sign(d[0])
    Q[:, 1] *= torch.sign(d[1])
    Q[:, 2] *= torch.sign(d[2])
    L[0] *= torch.sign(d[0])
    L[1] *= torch.sign(d[1])
    L[2] *= torch.sign(d[2])
    return Q, L


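# Sanity-check sketch (illustrative): Q should be orthogonal, L lower triangular
# with a non-negative diagonal, and Q @ L should reconstruct A (the permutation P
# is its own inverse, and the sign flips cancel between Q and L).
#
#   A = torch.randn(3, 3)
#   Q, L = ql_decomposition(A)
#   assert torch.allclose(Q @ L, A, atol=1e-5)
#   assert torch.allclose(Q @ Q.T, torch.eye(3), atol=1e-5)

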
def rays_to_cameras_homography(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    sampled_ray_idx=None,
    reproj_threshold=0.2,
    camera_coordinate_rays=False,
    average_centers=False,
    depth_resolution=1,
    directions_from_averaged_center=False,
):
    """
    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
    """
    device = rays.device
    origins = rays.get_origins(high_res=True)
    directions = rays.get_directions()

    if average_centers:
        camera_centers = torch.mean(origins, dim=1)
    else:
        camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)

    if directions_from_averaged_center:
        assert rays.mode == "segment"
        directions = rays.get_segments() - camera_centers.unsqueeze(1).repeat(
            (1, num_patches_x * num_patches_y, 1)
        )

    # Retrieve target rays
    I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
    I_patch_rays = cameras_to_rays(
        cameras=I_camera,
        num_patches_x=num_patches_x * depth_resolution,
        num_patches_y=num_patches_y * depth_resolution,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
        no_crop_param_device=device,
        mode="plucker",
    ).get_directions()

    if sampled_ray_idx is not None:
        I_patch_rays = I_patch_rays[:, sampled_ray_idx]

    # Compute optimal rotation to align rays
    if camera_coordinate_rays:
        directions_used = rays.get_camera_coordinate_rays()
    else:
        directions_used = directions

    Rs = []
    focal_lengths = []
    principal_points = []
    for i in range(rays.shape[-3]):
        R, f, pp = compute_optimal_rotation_intrinsics(
            I_patch_rays[i],
            directions_used[i],
            reproj_threshold=reproj_threshold,
        )
        Rs.append(R)
        focal_lengths.append(f)
        principal_points.append(pp)

    R = torch.stack(Rs)
    focal_lengths = torch.stack(focal_lengths)
    principal_points = torch.stack(principal_points)
    T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
    return PerspectiveCameras(
        R=R,
        T=T,
        focal_length=focal_lengths,
        principal_point=principal_points,
        device=device,
    )


def compute_optimal_rotation_alignment(A, B):
    """
    Compute optimal R that minimizes: || A - B @ R ||_F

    Args:
        A (torch.Tensor): (N, 3)
        B (torch.Tensor): (N, 3)

    Returns:
        R (torch.tensor): (3, 3)
    """
    # Normally with R @ B, this would be A @ B.T
    H = B.T @ A
    U, _, Vh = torch.linalg.svd(H, full_matrices=True)
    s = torch.linalg.det(U @ Vh)
    S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
    return U @ S_prime @ Vh


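# Usage sketch (illustrative): recovering a known rotation from noiseless
# correspondences via the Procrustes solution above.
#
#   R_true = torch.linalg.qr(torch.randn(3, 3)).Q  # flip one column if det < 0,
#                                                  # so R_true is a proper rotation
#   B = torch.nn.functional.normalize(torch.randn(100, 3), dim=-1)
#   A = B @ R_true
#   R_est = compute_optimal_rotation_alignment(A, B)   # R_est should match R_true

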
def compute_optimal_rotation_intrinsics(
    rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
):
    """
    Note: for some reason, f seems to be 1/f.

    Args:
        rays_origin (torch.Tensor): (N, 3)
        rays_target (torch.Tensor): (N, 3)
        z_threshold (float): Threshold for z value to be considered valid.

    Returns:
        R (torch.tensor): (3, 3)
        focal_length (torch.tensor): (2,)
        principal_point (torch.tensor): (2,)
    """
    device = rays_origin.device
    z_mask = torch.logical_and(
        torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold
    )[:, 2]
    rays_target = rays_target[z_mask]
    rays_origin = rays_origin[z_mask]
    rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
    rays_target = rays_target[:, :2] / rays_target[:, -1:]

    A, _ = cv2.findHomography(
        rays_origin.cpu().numpy(),
        rays_target.cpu().numpy(),
        cv2.RANSAC,
        reproj_threshold,
    )
    A = torch.from_numpy(A).float().to(device)

    if torch.linalg.det(A) < 0:
        # TODO: Find a better fix for this. This gives the correct R but incorrect
        # intrinsics.
        A = -A

    R, L = ql_decomposition(A)
    L = L / L[2][2]

    f = torch.stack((L[0][0], L[1][1]))
    # f = torch.stack(((L[0][0] + L[1][1]) / 2, (L[0][0] + L[1][1]) / 2))
    pp = torch.stack((L[2][0], L[2][1]))
    return R, f, pp


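# Note (added commentary): the estimated homography A maps the identity camera's
# normalized coordinates to the target ray directions, so (up to scale and the
# sign fix above) it factors into an orthogonal part and a triangular part.
# ql_decomposition supplies that split: the orthogonal factor is returned as the
# rotation R, while the normalized triangular factor L carries the focal length
# and principal point read off above.

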
def compute_ndc_coordinates(
    crop_parameters=None,
    use_half_pix=True,
    num_patches_x=16,
    num_patches_y=16,
    no_crop_param_device="cpu",
    distortion_coeffs=None,
    depths=None,
    return_zs=False,
    depth_resolution=1,
    nearest_neighbor=True,
):
    """
    Computes the NDC grid using crop_parameters. If crop_parameters is not provided,
    then it assumes that the crop is the entire image (corresponding to an NDC grid
    where the top-left corner is (1, 1) and the bottom-right corner is (-1, -1)).
    """
    if crop_parameters is None:
        cc_x, cc_y, width = 0, 0, 2
        device = no_crop_param_device
    else:
        if len(crop_parameters.shape) > 1:
            # Batched crop parameters: recurse per element.
            if distortion_coeffs is None:
                return torch.stack(
                    [
                        compute_ndc_coordinates(
                            crop_parameters=crop_param,
                            use_half_pix=use_half_pix,
                            num_patches_x=num_patches_x,
                            num_patches_y=num_patches_y,
                            nearest_neighbor=nearest_neighbor,
                            depths=depths[i] if depths is not None else None,
                        )
                        for i, crop_param in enumerate(crop_parameters)
                    ],
                    dim=0,
                )
            else:
                patch_params = zip(crop_parameters, distortion_coeffs)
                return torch.stack(
                    [
                        compute_ndc_coordinates(
                            crop_parameters=crop_param,
                            use_half_pix=use_half_pix,
                            num_patches_x=num_patches_x,
                            num_patches_y=num_patches_y,
                            distortion_coeffs=distortion_coeff,
                            nearest_neighbor=nearest_neighbor,
                        )
                        for crop_param, distortion_coeff in patch_params
                    ],
                    dim=0,
                )
        device = crop_parameters.device
        cc_x, cc_y, width, _ = crop_parameters

    dx = 1 / num_patches_x
    dy = 1 / num_patches_y
    if use_half_pix:
        min_y = 1 - dy
        max_y = -min_y
        min_x = 1 - dx
        max_x = -min_x
    else:
        min_y = min_x = 1
        max_y = -1 + 2 * dy
        max_x = -1 + 2 * dx

    y, x = torch.meshgrid(
        torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
        torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
        indexing="ij",
    )
    x_prime = x * width / 2 - cc_x
    y_prime = y * width / 2 - cc_y

    if distortion_coeffs is not None:
        points = torch.cat(
            (x_prime.flatten().unsqueeze(-1), y_prime.flatten().unsqueeze(-1)),
            dim=-1,
        )
        new_points = apply_distortion_tensor(
            points, distortion_coeffs[0], distortion_coeffs[1]
        )
        x_prime = new_points[:, 0].reshape((num_patches_x, num_patches_y))
        y_prime = new_points[:, 1].reshape((num_patches_x, num_patches_y))

    if depths is not None:
        if depth_resolution > 1:
            high_res_grid = compute_ndc_coordinates(
                crop_parameters=crop_parameters,
                use_half_pix=use_half_pix,
                num_patches_x=num_patches_x * depth_resolution,
                num_patches_y=num_patches_y * depth_resolution,
                no_crop_param_device=no_crop_param_device,
            )
            x_prime = high_res_grid[..., 0]
            y_prime = high_res_grid[..., 1]
        z = depths
    else:
        z = torch.ones_like(x)

    xyd_grid = torch.stack([x_prime, y_prime, z], dim=-1)
    xyd_grid_ones = torch.stack([x_prime, y_prime, torch.ones_like(x_prime)], dim=-1)

    if return_zs:
        return xyd_grid, z, xyd_grid_ones
    return xyd_grid


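# Usage sketch (illustrative): a full-image crop with the 16x16 defaults gives a
# grid spanning NDC from (1, 1) at the top-left to (-1, -1) at the bottom-right,
# with z = 1 and half-pixel offsets.
#
#   grid = compute_ndc_coordinates()              # (16, 16, 3)
#   grid[0, 0, :2]    # close to ( 0.9375,  0.9375)
#   grid[-1, -1, :2]  # close to (-0.9375, -0.9375)

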
def undistort_ndc_coordinates(
    ndc_coordinates, principal_point, focal_length, distortion_coefficients
):
    """
    Given NDC coordinates from a fisheye camera, computes where the coordinates
    would have been for a pinhole camera.

    Args:
        ndc_coordinates (torch.Tensor): (H, W, 3)
        principal_point (torch.Tensor): (2,)
        focal_length (torch.Tensor): (2,)
        distortion_coefficients (torch.Tensor): (4,)

    Returns:
        torch.Tensor: (H, W, 3)
    """
    device = ndc_coordinates.device
    x = ndc_coordinates[..., 0]
    y = ndc_coordinates[..., 1]
    d = ndc_coordinates[..., 2]

    # Compute normalized coordinates (using the OpenCV convention where negative
    # is top-left).
    x = -(x - principal_point[0]) / focal_length[0]
    y = -(y - principal_point[1]) / focal_length[1]
    distorted = torch.stack((x.flatten(), y.flatten()), 1).unsqueeze(1).cpu().numpy()
    undistorted = cv2.fisheye.undistortPoints(
        distorted, np.eye(3), distortion_coefficients.cpu().numpy(), np.eye(3)
    )
    u = torch.tensor(undistorted[:, 0, 0], device=device)
    v = torch.tensor(undistorted[:, 0, 1], device=device)

    # Map back to NDC.
    new_x = -u * focal_length[0] + principal_point[0]
    new_y = -v * focal_length[1] + principal_point[1]
    return torch.stack((new_x.reshape(x.shape), new_y.reshape(y.shape), d), -1)


def get_identity_cameras_with_intrinsics(cameras):
    N = len(cameras)
    device = cameras.R.device
    new_cameras = cameras.clone()
    new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((N, 1, 1))
    new_cameras.T = torch.zeros((N, 3), device=device)
    return new_cameras


def normalize_cameras_batch(
    cameras,
    scale=1.0,
    normalize_first_camera=False,
    depths=None,
    crop_parameters=None,
    num_patches_x=16,
    num_patches_y=16,
    distortion_coeffs=[None],
    first_cam_mediod=False,
    return_scales=False,
):
    new_cameras = []
    undo_transforms = []
    scales = []
    for i, cam in enumerate(cameras):
        if normalize_first_camera:
            # Normalize cameras such that the first camera is identity and the
            # origin is at the first camera center.
            s = 1
            if first_cam_mediod:
                s = scale_first_cam_mediod(
                    cam[0],
                    depths=depths[i][0].unsqueeze(0) if depths is not None else None,
                    crop_parameters=crop_parameters[i][0].unsqueeze(0),
                    num_patches_x=num_patches_x,
                    num_patches_y=num_patches_y,
                    distortion_coeffs=(
                        distortion_coeffs[i][0].unsqueeze(0)
                        if distortion_coeffs[i] is not None
                        else None
                    ),
                )
            scales.append(s)
            normalized_cameras = first_camera_transform(cam, s, rotation_only=False)
            undo_transform = None
        else:
            out = normalize_cameras(cam, scale=scale, return_scale=depths is not None)
            if depths is not None:
                normalized_cameras, undo_transform, s = out
                depths[i] *= s
                assert not depths.isnan().any()
            else:
                normalized_cameras, undo_transform = out
        new_cameras.append(normalized_cameras)
        undo_transforms.append(undo_transform)

    if return_scales:
        return new_cameras, undo_transforms, scales
    return new_cameras, undo_transforms


def scale_first_cam_mediod(
    cameras,
    scale=1.0,
    return_scale=False,
    depths=None,
    crop_parameters=None,
    num_patches_x=16,
    num_patches_y=16,
    distortion_coeffs=None,
):
    xy_grid = (
        compute_ndc_coordinates(
            depths=depths,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
        )
        .reshape((-1, 3))
        .to(depths.device)
    )
    verts = cameras.unproject_points(xy_grid, from_ndc=True, world_coordinates=True)
    p_intersect = torch.median(
        verts.reshape((-1, 3))[: num_patches_x * num_patches_y].float(), dim=0
    ).values.unsqueeze(0)
    d = torch.norm(p_intersect - cameras.get_camera_center())
    if d < 0.001:
        return 1
    return 1 / d


def normalize_cameras(cameras, scale=1.0, return_scale=False):
    """
    Normalizes cameras such that the optical axes point to the origin, the rotation
    is identity, and the norm of the translation of the first camera is 1.

    Args:
        cameras (pytorch3d.renderer.cameras.CamerasBase).
        scale (float): Norm of the translation of the first camera.

    Returns:
        new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
        undo_transform (function): Function that undoes the normalization.
    """
    # Let distance from first camera to origin be unit
    new_cameras = cameras.clone()
    # Note: R here is potentially not a valid rotation matrix.
    new_transform = new_cameras.get_world_to_view_transform()

    p_intersect, dist, _, _, _ = compute_optical_axis_intersection(cameras)
    if p_intersect is None:
        print("Warning: optical axes code has a nan. Returning identity cameras.")
        new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
        new_cameras.T[:] = torch.tensor(
            [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
        )
        if return_scale:
            return new_cameras, lambda x: x, 1 / scale
        return new_cameras, lambda x: x

    d = dist.squeeze(dim=1).squeeze(dim=0)[0]
    # Degenerate case
    if d == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
        assert False

    # Can't figure out how to make scale part of the transform too without messing
    # up R. Ideally, we would just wrap it all in a single Pytorch3D transform so
    # that it would work with any structure (e.g. PointClouds, Meshes).
    tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse()
    tT = Translate(p_intersect)
    t = tR.compose(tT)
    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale

    def undo_transform(cameras):
        cameras_copy = cameras.clone()
        cameras_copy.T *= d / scale
        new_t = (
            t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix()
        )
        cameras_copy.R = new_t[:, :3, :3]
        cameras_copy.T = new_t[:, 3, :3]
        return cameras_copy

    if return_scale:
        return new_cameras, undo_transform, scale / d
    return new_cameras, undo_transform


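# Usage sketch (illustrative): the normalization is invertible via the returned
# closure, so a round-trip should recover the original extrinsics.
#
#   new_cams, undo = normalize_cameras(cams)
#   restored = undo(new_cams)
#   # restored.R / restored.T should match cams.R / cams.T up to numerics.

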
def first_camera_transform(cameras, s, rotation_only=True):
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()
    tR = Rotate(new_cameras.R[0].unsqueeze(0))
    if rotation_only:
        t = tR.inverse()
    else:
        tT = Translate(new_cameras.T[0].unsqueeze(0))
        t = tR.compose(tT).inverse()

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] * s
    return new_cameras