# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn.functional as F

from .forward_warp_utils_pytorch import unproject_points


def apply_transformation(Bx4x4, another_matrix):
    """Right-multiplies each matrix in a (B, 4, 4) batch by another 4x4 matrix (or batch of matrices)."""
    B = Bx4x4.shape[0]
    if another_matrix.dim() == 2:
        another_matrix = another_matrix.unsqueeze(0).expand(B, -1, -1)  # Make another_matrix compatible with batch size
    transformed_matrix = torch.bmm(Bx4x4, another_matrix)  # Shape: (B, 4, 4)
    return transformed_matrix
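

# Minimal usage sketch for apply_transformation (illustrative values only; assumes CPU tensors).
def _example_apply_transformation():
    """Composes a batch of identity poses with a single 4x4 matrix; the shapes are the point, not the values."""
    poses = torch.eye(4).unsqueeze(0).repeat(8, 1, 1)  # (8, 4, 4) batch of poses
    world_to_camera = torch.eye(4)  # single 4x4 matrix, broadcast to the batch inside apply_transformation
    return apply_transformation(poses, world_to_camera)  # (8, 4, 4)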


def look_at_matrix(camera_pos, target, invert_pos=True):
    """Creates a 4x4 look-at matrix, keeping the camera pointing towards a target."""
    forward = (target - camera_pos).float()
    forward = forward / torch.norm(forward)
    up = torch.tensor([0.0, 1.0, 0.0], device=camera_pos.device)  # assuming Y-up coordinate system
    right = torch.linalg.cross(up, forward)
    right = right / torch.norm(right)
    up = torch.linalg.cross(forward, right)
    look_at = torch.eye(4, device=camera_pos.device)
    look_at[0, :3] = right
    look_at[1, :3] = up
    look_at[2, :3] = forward
    look_at[:3, 3] = (-camera_pos) if invert_pos else camera_pos
    return look_at
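

# Minimal usage sketch for look_at_matrix (illustrative values only; assumes CPU tensors).
def _example_look_at_matrix():
    """Orients a camera placed slightly to the right of the origin toward a point 2 m ahead."""
    camera_pos = torch.tensor([0.1, 0.0, 0.0])
    target = torch.tensor([0.0, 0.0, 2.0])
    return look_at_matrix(camera_pos, target)  # (4, 4) view matrix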


def create_horizontal_trajectory(
    world_to_camera_matrix,
    center_depth,
    positive=True,
    n_steps=13,
    distance=0.1,
    device="cuda",
    axis="x",
    camera_rotation="center_facing",
):
    """Creates a linear camera trajectory along a single axis, returned as world-to-camera matrices."""
    look_at = torch.tensor([0.0, 0.0, center_depth]).to(device)
    # Linear motion key points
    trajectory = []
    translation_positions = []
    initial_camera_pos = torch.tensor([0.0, 0.0, 0.0], device=device)
    for i in range(n_steps):
        if axis == "x":  # pos - right
            x = i * distance * center_depth / n_steps * (1 if positive else -1)
            y = 0
            z = 0
        elif axis == "y":  # pos - down
            x = 0
            y = i * distance * center_depth / n_steps * (1 if positive else -1)
            z = 0
        elif axis == "z":  # pos - in
            x = 0
            y = 0
            z = i * distance * center_depth / n_steps * (1 if positive else -1)
        else:
            raise ValueError("Axis should be x, y or z")
        translation_positions.append(torch.tensor([x, y, z], device=device))
    for pos in translation_positions:
        camera_pos = initial_camera_pos + pos
        if camera_rotation == "trajectory_aligned":
            _look_at = look_at + pos * 2
        elif camera_rotation == "center_facing":
            _look_at = look_at
        elif camera_rotation == "no_rotation":
            _look_at = look_at + pos
        else:
            raise ValueError("Camera rotation should be center_facing, trajectory_aligned or no_rotation")
        view_matrix = look_at_matrix(camera_pos, _look_at)
        trajectory.append(view_matrix)
    trajectory = torch.stack(trajectory)
    return apply_transformation(trajectory, world_to_camera_matrix)
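

# Minimal usage sketch for create_horizontal_trajectory (illustrative values only; runs on CPU).
def _example_create_horizontal_trajectory():
    """Builds 13 poses dollying to the right while staying aimed at a point 2 m in front of the camera."""
    trajectory = create_horizontal_trajectory(
        world_to_camera_matrix=torch.eye(4),
        center_depth=2.0,
        positive=True,
        n_steps=13,
        distance=0.1,
        device="cpu",
        axis="x",
        camera_rotation="center_facing",
    )
    return trajectory  # (13, 4, 4) world-to-camera matrices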


def create_spiral_trajectory(
    world_to_camera_matrix,
    center_depth,
    radius_x=0.03,
    radius_y=0.02,
    radius_z=0.0,
    positive=True,
    camera_rotation="center_facing",
    n_steps=13,
    device="cuda",
    start_from_zero=True,
    num_circles=1,
):
    """Creates a spiral (orbit-like) camera trajectory around the initial view, returned as world-to-camera matrices."""
    look_at = torch.tensor([0.0, 0.0, center_depth]).to(device)
    # Spiral motion key points
    trajectory = []
    spiral_positions = []
    initial_camera_pos = torch.tensor([0.0, 0.0, 0.0], device=device)
    example_scale = 1.0
    theta_max = 2 * math.pi * num_circles
    for i in range(n_steps):
        theta = theta_max * i / (n_steps - 1)  # angle for each point
        if start_from_zero:
            x = radius_x * (math.cos(theta) - 1) * (1 if positive else -1) * (center_depth / example_scale)
        else:
            x = radius_x * (math.cos(theta)) * (center_depth / example_scale)
        y = radius_y * math.sin(theta) * (center_depth / example_scale)
        z = radius_z * math.sin(theta) * (center_depth / example_scale)
        spiral_positions.append(torch.tensor([x, y, z], device=device))
    for pos in spiral_positions:
        if camera_rotation == "center_facing":
            view_matrix = look_at_matrix(initial_camera_pos + pos, look_at)
        elif camera_rotation == "trajectory_aligned":
            view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos * 2)
        elif camera_rotation == "no_rotation":
            view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos)
        else:
            raise ValueError("Camera rotation should be center_facing, trajectory_aligned or no_rotation")
        trajectory.append(view_matrix)
    trajectory = torch.stack(trajectory)
    return apply_transformation(trajectory, world_to_camera_matrix)
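

# Minimal usage sketch for create_spiral_trajectory (illustrative values only; runs on CPU).
def _example_create_spiral_trajectory():
    """Builds a single 13-step circle around a point 2 m in front of the camera."""
    trajectory = create_spiral_trajectory(
        world_to_camera_matrix=torch.eye(4),
        center_depth=2.0,
        radius_x=0.05,
        radius_y=0.05,
        n_steps=13,
        device="cpu",
    )
    return trajectory  # (13, 4, 4) world-to-camera matrices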


def generate_camera_trajectory(
    trajectory_type: str,
    initial_w2c: torch.Tensor,  # Shape: (4, 4)
    initial_intrinsics: torch.Tensor,  # Shape: (3, 3)
    num_frames: int,
    movement_distance: float,
    camera_rotation: str,
    center_depth: float = 1.0,
    device: str = "cuda",
):
    """
    Generates a sequence of camera poses (world-to-camera matrices) and intrinsics
    for a specified trajectory type.

    Args:
        trajectory_type: Type of trajectory ("left", "right", "up", "down", "zoom_in", "zoom_out",
            "clockwise", "counterclockwise").
        initial_w2c: Initial world-to-camera matrix (4x4 tensor or num_frames x 4 x 4 tensor).
        initial_intrinsics: Camera intrinsics matrix (3x3 tensor or num_frames x 3 x 3 tensor).
        num_frames: Number of frames (steps) in the trajectory.
        movement_distance: Distance factor for the camera movement.
        camera_rotation: Type of camera rotation ('center_facing', 'no_rotation', 'trajectory_aligned').
        center_depth: Depth of the center point the camera might focus on.
        device: Computation device ("cuda" or "cpu").

    Returns:
        A tuple (generated_w2cs, generated_intrinsics):
            - generated_w2cs: Batch of world-to-camera matrices for the trajectory (1, num_frames, 4, 4 tensor).
            - generated_intrinsics: Batch of camera intrinsics for the trajectory (1, num_frames, 3, 3 tensor).
    """
    if trajectory_type in ["clockwise", "counterclockwise"]:
        new_w2cs_seq = create_spiral_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=trajectory_type == "clockwise",
            device=device,
            camera_rotation=camera_rotation,
            radius_x=movement_distance,
            radius_y=movement_distance,
        )
    else:
        if trajectory_type == "left":
            positive = False
            axis = "x"
        elif trajectory_type == "right":
            positive = True
            axis = "x"
        elif trajectory_type == "up":
            positive = False  # Assuming 'up' means camera moves in negative y direction if y points down
            axis = "y"
        elif trajectory_type == "down":
            positive = True  # Assuming 'down' means camera moves in positive y direction if y points down
            axis = "y"
        elif trajectory_type == "zoom_in":
            positive = True  # Assuming 'zoom_in' means camera moves in positive z direction (forward)
            axis = "z"
        elif trajectory_type == "zoom_out":
            positive = False  # Assuming 'zoom_out' means camera moves in negative z direction (backward)
            axis = "z"
        else:
            raise ValueError(f"Unsupported trajectory type: {trajectory_type}")
        # Generate world-to-camera matrices using create_horizontal_trajectory
        new_w2cs_seq = create_horizontal_trajectory(
            world_to_camera_matrix=initial_w2c,
            center_depth=center_depth,
            n_steps=num_frames,
            positive=positive,
            axis=axis,
            distance=movement_distance,
            device=device,
            camera_rotation=camera_rotation,
        )
    generated_w2cs = new_w2cs_seq.unsqueeze(0)  # Shape: [1, num_frames, 4, 4]
    if initial_intrinsics.dim() == 2:
        generated_intrinsics = initial_intrinsics.unsqueeze(0).unsqueeze(0).repeat(1, num_frames, 1, 1)
    else:
        generated_intrinsics = initial_intrinsics.unsqueeze(0)
    return generated_w2cs, generated_intrinsics
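

# Minimal usage sketch for generate_camera_trajectory (placeholder pinhole intrinsics, not from any dataset; runs on CPU).
def _example_generate_camera_trajectory():
    """Generates a 13-frame rightward trajectory from an identity pose."""
    intrinsics = torch.tensor([[500.0, 0.0, 160.0], [0.0, 500.0, 120.0], [0.0, 0.0, 1.0]])
    w2cs, intrinsics_seq = generate_camera_trajectory(
        trajectory_type="right",
        initial_w2c=torch.eye(4),
        initial_intrinsics=intrinsics,
        num_frames=13,
        movement_distance=0.1,
        camera_rotation="center_facing",
        center_depth=2.0,
        device="cpu",
    )
    return w2cs.shape, intrinsics_seq.shape  # torch.Size([1, 13, 4, 4]), torch.Size([1, 13, 3, 3])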


def _align_inv_depth_to_depth(
    source_inv_depth: torch.Tensor,
    target_depth: torch.Tensor,
    target_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Apply an affine transformation to align source inverse depth to target depth.

    Args:
        source_inv_depth: Inverse depth map to be aligned. Shape: (H, W).
        target_depth: Target depth map. Shape: (H, W).
        target_mask: Mask of valid target pixels. Shape: (H, W).

    Returns:
        Aligned depth map. Shape: (H, W).
    """
    target_inv_depth = 1.0 / target_depth
    source_mask = source_inv_depth > 0
    target_depth_mask = target_depth > 0
    if target_mask is None:
        target_mask = target_depth_mask
    else:
        target_mask = torch.logical_and(target_mask > 0, target_depth_mask)
    # Remove outliers by keeping only values between the 10th and 90th percentiles
    outlier_quantiles = torch.tensor([0.1, 0.9], device=source_inv_depth.device)
    source_data_low, source_data_high = torch.quantile(source_inv_depth[source_mask], outlier_quantiles)
    target_data_low, target_data_high = torch.quantile(target_inv_depth[target_mask], outlier_quantiles)
    source_mask = (source_inv_depth > source_data_low) & (source_inv_depth < source_data_high)
    target_mask = (target_inv_depth > target_data_low) & (target_inv_depth < target_data_high)
    mask = torch.logical_and(source_mask, target_mask)
    # Solve for scale and bias with least squares on the jointly valid pixels
    source_data = source_inv_depth[mask].view(-1, 1)
    target_data = target_inv_depth[mask].view(-1, 1)
    ones = torch.ones((source_data.shape[0], 1), device=source_data.device)
    source_data_h = torch.cat([source_data, ones], dim=1)
    transform_matrix = torch.linalg.lstsq(source_data_h, target_data).solution
    scale, bias = transform_matrix[0, 0], transform_matrix[1, 0]
    aligned_inv_depth = source_inv_depth * scale + bias
    return 1.0 / aligned_inv_depth
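

# Minimal self-check sketch for _align_inv_depth_to_depth on synthetic data: the source inverse depth
# is an affine (scale + shift) copy of the target's inverse depth, so the least-squares fit should
# map it back onto the target depth.
def _example_align_inv_depth_to_depth():
    target_depth = torch.rand(32, 32) * 4.0 + 1.0  # depths in [1, 5)
    source_inv_depth = 0.5 / target_depth + 0.2  # affine in inverse-depth space
    aligned_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth)
    return (aligned_depth - target_depth).abs().max()  # should be close to zero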


def align_depth(
    source_depth: torch.Tensor,
    target_depth: torch.Tensor,
    target_mask: torch.Tensor,
    k: torch.Tensor | None = None,
    c2w: torch.Tensor | None = None,
    alignment_method: str = "rigid",
    num_iters: int = 100,
    lambda_arap: float = 0.1,
    smoothing_kernel_size: int = 3,
) -> torch.Tensor:
    """Aligns a source depth map to a target depth map.

    With alignment_method="rigid", a global scale and shift is fitted in inverse-depth space.
    With alignment_method="non_rigid", a per-pixel scale map is additionally optimized on top of
    the rigid initialization, with an ARAP smoothness regularizer.
    """
    if alignment_method == "rigid":
        source_inv_depth = 1.0 / source_depth
        source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask)
        return source_depth
    elif alignment_method == "non_rigid":
        if k is None or c2w is None:
            raise ValueError("Camera intrinsics (k) and camera-to-world matrix (c2w) are required for non-rigid alignment")
        source_inv_depth = 1.0 / source_depth
        source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask)
        # Initialize per-pixel scale map
        sc_map = torch.ones_like(source_depth).float().requires_grad_(True)
        optimizer = torch.optim.Adam(params=[sc_map], lr=0.001)
        # Unproject target depth
        target_unprojected = unproject_points(
            target_depth.unsqueeze(0).unsqueeze(0),  # Add batch and channel dimensions
            c2w.unsqueeze(0),  # Add batch dimension
            k.unsqueeze(0),  # Add batch dimension
            is_depth=True,
            mask=target_mask.unsqueeze(0).unsqueeze(0),  # Add batch and channel dimensions
        ).squeeze(0)  # Remove batch dimension
        # Create box-filter smoothing kernel
        smoothing_kernel = torch.ones(
            (1, 1, smoothing_kernel_size, smoothing_kernel_size),
            device=source_depth.device,
        ) / (smoothing_kernel_size**2)
        for _ in range(num_iters):
            # Unproject scaled source depth
            source_unprojected = unproject_points(
                (source_depth * sc_map).unsqueeze(0).unsqueeze(0),  # Add batch and channel dimensions
                c2w.unsqueeze(0),  # Add batch dimension
                k.unsqueeze(0),  # Add batch dimension
                is_depth=True,
                mask=target_mask.unsqueeze(0).unsqueeze(0),  # Add batch and channel dimensions
            ).squeeze(0)  # Remove batch dimension
            # Data loss: distance between unprojected points at valid target pixels
            data_loss = torch.abs(source_unprojected[target_mask] - target_unprojected[target_mask]).mean()
            # Apply smoothing filter to sc_map
            sc_map_reshaped = sc_map.unsqueeze(0).unsqueeze(0)
            sc_map_smoothed = F.conv2d(
                sc_map_reshaped,
                smoothing_kernel,
                padding=smoothing_kernel_size // 2,
            ).squeeze(0).squeeze(0)
            # ARAP loss: keep the scale map locally smooth
            arap_loss = torch.abs(sc_map_smoothed - sc_map).mean()
            # Total loss
            loss = data_loss + lambda_arap * arap_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return source_depth * sc_map
    else:
        raise ValueError(f"Unsupported alignment method: {alignment_method}")