# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import math import torch.nn.functional as F from .forward_warp_utils_pytorch import unproject_points def apply_transformation(Bx4x4, another_matrix): B = Bx4x4.shape[0] if another_matrix.dim() == 2: another_matrix = another_matrix.unsqueeze(0).expand(B, -1, -1) # Make another_matrix compatible with batch size transformed_matrix = torch.bmm(Bx4x4, another_matrix) # Shape: (B, 4, 4) return transformed_matrix def look_at_matrix(camera_pos, target, invert_pos=True): """Creates a 4x4 look-at matrix, keeping the camera pointing towards a target.""" forward = (target - camera_pos).float() forward = forward / torch.norm(forward) up = torch.tensor([0.0, 1.0, 0.0], device=camera_pos.device) # assuming Y-up coordinate system right = torch.cross(up, forward) right = right / torch.norm(right) up = torch.cross(forward, right) look_at = torch.eye(4, device=camera_pos.device) look_at[0, :3] = right look_at[1, :3] = up look_at[2, :3] = forward look_at[:3, 3] = (-camera_pos) if invert_pos else camera_pos return look_at def create_horizontal_trajectory( world_to_camera_matrix, center_depth, positive=True, n_steps=13, distance=0.1, device="cuda", axis="x", camera_rotation="center_facing" ): look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) # Spiral motion key points trajectory = [] translation_positions = [] initial_camera_pos = torch.tensor([0, 0, 0], device=device) for i in range(n_steps): if axis == "x": # pos - right x = i * distance * center_depth / n_steps * (1 if positive else -1) y = 0 z = 0 elif axis == "y": # pos - down x = 0 y = i * distance * center_depth / n_steps * (1 if positive else -1) z = 0 elif axis == "z": # pos - in x = 0 y = 0 z = i * distance * center_depth / n_steps * (1 if positive else -1) else: raise ValueError("Axis should be x, y or z") translation_positions.append(torch.tensor([x, y, z], device=device)) for pos in translation_positions: camera_pos = initial_camera_pos + pos if camera_rotation == "trajectory_aligned": _look_at = look_at + pos * 2 elif camera_rotation == "center_facing": _look_at = look_at elif camera_rotation == "no_rotation": _look_at = look_at + pos else: raise ValueError("Camera rotation should be center_facing or trajectory_aligned") view_matrix = look_at_matrix(camera_pos, _look_at) trajectory.append(view_matrix) trajectory = torch.stack(trajectory) return apply_transformation(trajectory, world_to_camera_matrix) def create_spiral_trajectory( world_to_camera_matrix, center_depth, radius_x=0.03, radius_y=0.02, radius_z=0.0, positive=True, camera_rotation="center_facing", n_steps=13, device="cuda", start_from_zero=True, num_circles=1, ): look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) # Spiral motion key points trajectory = [] spiral_positions = [] initial_camera_pos = torch.tensor([0, 0, 0], device=device) # world_to_camera_matrix[:3, 3].clone() example_scale = 1.0 theta_max = 2 * math.pi * num_circles for i in range(n_steps): # theta = 2 * math.pi * i / (n_steps-1) # angle for each point theta = theta_max * i / (n_steps - 1) # angle for each point if start_from_zero: x = radius_x * (math.cos(theta) - 1) * (1 if positive else -1) * (center_depth / example_scale) else: x = radius_x * (math.cos(theta)) * (center_depth / example_scale) y = radius_y * math.sin(theta) * (center_depth / example_scale) z = radius_z * math.sin(theta) * (center_depth / example_scale) spiral_positions.append(torch.tensor([x, y, z], device=device)) for pos in spiral_positions: if camera_rotation == "center_facing": view_matrix = look_at_matrix(initial_camera_pos + pos, look_at) elif camera_rotation == "trajectory_aligned": view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos * 2) elif camera_rotation == "no_rotation": view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos) else: raise ValueError("Camera rotation should be center_facing, trajectory_aligned or no_rotation") trajectory.append(view_matrix) trajectory = torch.stack(trajectory) return apply_transformation(trajectory, world_to_camera_matrix) def generate_camera_trajectory( trajectory_type: str, initial_w2c: torch.Tensor, # Shape: (4, 4) initial_intrinsics: torch.Tensor, # Shape: (3, 3) num_frames: int, movement_distance: float, camera_rotation: str, center_depth: float = 1.0, device: str = "cuda", ): """ Generates a sequence of camera poses (world-to-camera matrices) and intrinsics for a specified trajectory type. Args: trajectory_type: Type of trajectory (e.g., "left", "right", "up", "down", "zoom_in", "zoom_out"). initial_w2c: Initial world-to-camera matrix (4x4 tensor or num_framesx4x4 tensor). initial_intrinsics: Camera intrinsics matrix (3x3 tensor or num_framesx3x3 tensor). num_frames: Number of frames (steps) in the trajectory. movement_distance: Distance factor for the camera movement. camera_rotation: Type of camera rotation ('center_facing', 'no_rotation', 'trajectory_aligned'). center_depth: Depth of the center point the camera might focus on. device: Computation device ("cuda" or "cpu"). Returns: A tuple (generated_w2cs, generated_intrinsics): - generated_w2cs: Batch of world-to-camera matrices for the trajectory (1, num_frames, 4, 4 tensor). - generated_intrinsics: Batch of camera intrinsics for the trajectory (1, num_frames, 3, 3 tensor). """ if trajectory_type in ["clockwise", "counterclockwise"]: new_w2cs_seq = create_spiral_trajectory( world_to_camera_matrix=initial_w2c, center_depth=center_depth, n_steps=num_frames, positive=trajectory_type == "clockwise", device=device, camera_rotation=camera_rotation, radius_x=movement_distance, radius_y=movement_distance, ) else: if trajectory_type == "left": positive = False axis = "x" elif trajectory_type == "right": positive = True axis = "x" elif trajectory_type == "up": positive = False # Assuming 'up' means camera moves in negative y direction if y points down axis = "y" elif trajectory_type == "down": positive = True # Assuming 'down' means camera moves in positive y direction if y points down axis = "y" elif trajectory_type == "zoom_in": positive = True # Assuming 'zoom_in' means camera moves in positive z direction (forward) axis = "z" elif trajectory_type == "zoom_out": positive = False # Assuming 'zoom_out' means camera moves in negative z direction (backward) axis = "z" else: raise ValueError(f"Unsupported trajectory type: {trajectory_type}") # Generate world-to-camera matrices using create_horizontal_trajectory new_w2cs_seq = create_horizontal_trajectory( world_to_camera_matrix=initial_w2c, center_depth=center_depth, n_steps=num_frames, positive=positive, axis=axis, distance=movement_distance, device=device, camera_rotation=camera_rotation, ) generated_w2cs = new_w2cs_seq.unsqueeze(0) # Shape: [1, num_frames, 4, 4] if initial_intrinsics.dim() == 2: generated_intrinsics = initial_intrinsics.unsqueeze(0).unsqueeze(0).repeat(1, num_frames, 1, 1) else: generated_intrinsics = initial_intrinsics.unsqueeze(0) return generated_w2cs, generated_intrinsics def _align_inv_depth_to_depth( source_inv_depth: torch.Tensor, target_depth: torch.Tensor, target_mask: torch.Tensor | None = None, ) -> torch.Tensor: """ Apply affine transformation to align source inverse depth to target depth. Args: source_inv_depth: Inverse depth map to be aligned. Shape: (H, W). target_depth: Target depth map. Shape: (H, W). target_mask: Mask of valid target pixels. Shape: (H, W). Returns: Aligned Depth map. Shape: (H, W). """ target_inv_depth = 1.0 / target_depth source_mask = source_inv_depth > 0 target_depth_mask = target_depth > 0 if target_mask is None: target_mask = target_depth_mask else: target_mask = torch.logical_and(target_mask > 0, target_depth_mask) # Remove outliers outlier_quantiles = torch.tensor([0.1, 0.9], device=source_inv_depth.device) source_data_low, source_data_high = torch.quantile(source_inv_depth[source_mask], outlier_quantiles) target_data_low, target_data_high = torch.quantile(target_inv_depth[target_mask], outlier_quantiles) source_mask = (source_inv_depth > source_data_low) & (source_inv_depth < source_data_high) target_mask = (target_inv_depth > target_data_low) & (target_inv_depth < target_data_high) mask = torch.logical_and(source_mask, target_mask) source_data = source_inv_depth[mask].view(-1, 1) target_data = target_inv_depth[mask].view(-1, 1) ones = torch.ones((source_data.shape[0], 1), device=source_data.device) source_data_h = torch.cat([source_data, ones], dim=1) transform_matrix = torch.linalg.lstsq(source_data_h, target_data).solution scale, bias = transform_matrix[0, 0], transform_matrix[1, 0] aligned_inv_depth = source_inv_depth * scale + bias return 1.0 / aligned_inv_depth def align_depth( source_depth: torch.Tensor, target_depth: torch.Tensor, target_mask: torch.Tensor, k: torch.Tensor = None, c2w: torch.Tensor = None, alignment_method: str = "rigid", num_iters: int = 100, lambda_arap: float = 0.1, smoothing_kernel_size: int = 3, ) -> torch.Tensor: if alignment_method == "rigid": source_inv_depth = 1.0 / source_depth source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask) return source_depth elif alignment_method == "non_rigid": if k is None or c2w is None: raise ValueError("Camera intrinsics (k) and camera-to-world matrix (c2w) are required for non-rigid alignment") source_inv_depth = 1.0 / source_depth source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask) # Initialize scale map sc_map = torch.ones_like(source_depth).float().to(source_depth.device).requires_grad_(True) optimizer = torch.optim.Adam(params=[sc_map], lr=0.001) # Unproject target depth target_unprojected = unproject_points( target_depth.unsqueeze(0).unsqueeze(0), # Add batch and channel dimensions c2w.unsqueeze(0), # Add batch dimension k.unsqueeze(0), # Add batch dimension is_depth=True, mask=target_mask.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions ).squeeze(0) # Remove batch dimension # Create smoothing kernel smoothing_kernel = torch.ones( (1, 1, smoothing_kernel_size, smoothing_kernel_size), device=source_depth.device ) / (smoothing_kernel_size**2) for _ in range(num_iters): # Unproject scaled source depth source_unprojected = unproject_points( (source_depth * sc_map).unsqueeze(0).unsqueeze(0), # Add batch and channel dimensions c2w.unsqueeze(0), # Add batch dimension k.unsqueeze(0), # Add batch dimension is_depth=True, mask=target_mask.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions ).squeeze(0) # Remove batch dimension # Data loss data_loss = torch.abs(source_unprojected[target_mask] - target_unprojected[target_mask]).mean() # Apply smoothing filter to sc_map sc_map_reshaped = sc_map.unsqueeze(0).unsqueeze(0) sc_map_smoothed = F.conv2d( sc_map_reshaped, smoothing_kernel, padding=smoothing_kernel_size // 2 ).squeeze(0).squeeze(0) # ARAP loss arap_loss = torch.abs(sc_map_smoothed - sc_map).mean() # Total loss loss = data_loss + lambda_arap * arap_loss optimizer.zero_grad() loss.backward() optimizer.step() return source_depth * sc_map else: raise ValueError(f"Unsupported alignment method: {alignment_method}")