# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from einops import rearrange from cosmos_predict1.diffusion.inference.forward_warp_utils_pytorch import ( forward_warp, reliable_depth_mask_range_batch, unproject_points, ) from cosmos_predict1.diffusion.inference.camera_utils import align_depth class Cache3D_Base: def __init__( self, input_image, input_depth, input_w2c, input_intrinsics, input_mask=None, input_format=None, input_points=None, weight_dtype=torch.float32, is_depth=True, device="cuda", filter_points_threshold=1.0, foreground_masking=False, ): """ input_image: Tensor with varying dimensions. input_format: List of dimension labels corresponding to input_image's dimensions. E.g., ['B', 'C', 'H', 'W'], ['B', 'F', 'C', 'H', 'W'], etc. """ self.weight_dtype = weight_dtype self.is_depth = is_depth self.device = device self.filter_points_threshold = filter_points_threshold self.foreground_masking = foreground_masking if input_format is None: assert input_image.dim() == 4 input_format = ["B", "C", "H", "W"] # Map dimension names to their indices in input_image format_to_indices = {dim: idx for idx, dim in enumerate(input_format)} input_shape = input_image.shape if input_mask is not None: input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C")) # B (batch size), F (frame count), N dimensions: no aggregation during warping. # Only broadcasting over F to match the target w2c. # V: aggregate via concatenation or duster B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1 # batch F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1 # frame N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1 # buffer V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1 # view H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None # Desired dimension order desired_dims = ["B", "F", "N", "V", "C", "H", "W"] # Build permute order based on input_format permute_order = [] for dim in desired_dims: idx = format_to_indices.get(dim) if idx is not None: permute_order.append(idx) else: # Placeholder for dimensions to be added later permute_order.append(None) # Remove None values for permute operation permute_indices = [idx for idx in permute_order if idx is not None] input_image = input_image.permute(*permute_indices) # Insert dimensions of size 1 where necessary for i, idx in enumerate(permute_order): if idx is None: input_image = input_image.unsqueeze(i) # Now input_image has the shape B x F x N x V x C x H x W if input_mask is not None: self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:] self.input_mask = self.input_mask.to("cpu") else: self.input_mask = None self.input_image = input_image self.input_image = self.input_image.to(weight_dtype).to("cpu") if input_points is not None: self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu") self.input_depth = None else: input_depth = torch.nan_to_num(input_depth, nan=100) input_depth = torch.clamp(input_depth, min=0, max=100) if weight_dtype == torch.float16: input_depth = torch.clamp(input_depth, max=70) self.input_points = ( self._compute_input_points( input_depth.reshape(-1, 1, H, W), input_w2c.reshape(-1, 4, 4), input_intrinsics.reshape(-1, 3, 3), ) .to(weight_dtype) .reshape(B, F, N, V, H, W, 3) .to("cpu") ) self.input_depth = input_depth if self.filter_points_threshold < 1.0 and input_depth is not None: input_depth = input_depth.reshape(-1, 1, H, W) depth_mask = reliable_depth_mask_range_batch(input_depth, ratio_thresh=self.filter_points_threshold).reshape(B, F, N, V, 1, H, W) if self.input_mask is None: self.input_mask = depth_mask.to("cpu") else: self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device) self.boundary_mask = None if foreground_masking: input_depth = input_depth.reshape(-1, 1, H, W) depth_mask = reliable_depth_mask_range_batch(input_depth) self.boundary_mask = (~depth_mask).reshape(B, F, N, V, 1, H, W).to("cpu") def _compute_input_points(self, input_depth, input_w2c, input_intrinsics): input_points = unproject_points( input_depth, input_w2c, input_intrinsics, is_depth=self.is_depth, ) return input_points def update_cache(self): raise NotImplementedError def input_frame_count(self) -> int: return self.input_image.shape[1] def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0): bs, F_target, _, _ = target_w2cs.shape B, F, N, V, C, H, W = self.input_image.shape assert bs == B target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4) target_intrinsics = ( target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3, 3).reshape(-1, 3, 3) ) first_images = rearrange(self.input_image[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, C, H, W), "B F N V C H W-> (B F N) V C H W").to(self.device) first_points = rearrange( self.input_points[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, H, W, 3), "B F N V H W C-> (B F N) V H W C" ).to(self.device) first_masks = rearrange( self.input_mask[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W" ).to(self.device) if self.input_mask is not None else None boundary_masks = rearrange( self.boundary_mask.expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W" ) if self.boundary_mask is not None else None if first_images.shape[1] == 1: warp_chunk_size = 2 rendered_warp_images = [] rendered_warp_masks = [] rendered_warp_depth = [] rendered_warped_flows = [] first_images = first_images.squeeze(1) first_points = first_points.squeeze(1) first_masks = first_masks.squeeze(1) if first_masks is not None else None for i in range(0, first_images.shape[0], warp_chunk_size): ( rendered_warp_images_chunk, rendered_warp_masks_chunk, rendered_warp_depth_chunk, rendered_warped_flows_chunk, ) = forward_warp( first_images[i : i + warp_chunk_size], mask1=first_masks[i : i + warp_chunk_size] if first_masks is not None else None, depth1=None, transformation1=None, transformation2=target_w2cs[i : i + warp_chunk_size], intrinsic1=target_intrinsics[i : i + warp_chunk_size], intrinsic2=target_intrinsics[i : i + warp_chunk_size], render_depth=render_depth, world_points1=first_points[i : i + warp_chunk_size], foreground_masking=self.foreground_masking, boundary_mask=boundary_masks[i : i + warp_chunk_size, 0, 0] if boundary_masks is not None else None ) rendered_warp_images.append(rendered_warp_images_chunk) rendered_warp_masks.append(rendered_warp_masks_chunk) rendered_warp_depth.append(rendered_warp_depth_chunk) rendered_warped_flows.append(rendered_warped_flows_chunk) rendered_warp_images = torch.cat(rendered_warp_images, dim=0) rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0) if render_depth: rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0) rendered_warped_flows = torch.cat(rendered_warped_flows, dim=0) else: raise NotImplementedError pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N) masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N) if render_depth: pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N) return pixels, masks class Cache3D_Buffer(Cache3D_Base): def __init__(self, frame_buffer_max=0, noise_aug_strength=0, generator=None, **kwargs): super().__init__(**kwargs) self.frame_buffer_max = frame_buffer_max self.noise_aug_strength = noise_aug_strength self.generator = generator def update_cache(self, new_image, new_depth, new_w2c, new_mask=None, new_intrinsics=None, depth_alignment=True, alignment_method="non_rigid"): # 3D cache new_image = new_image.to(self.weight_dtype).to(self.device) new_depth = new_depth.to(self.weight_dtype).to(self.device) new_w2c = new_w2c.to(self.weight_dtype).to(self.device) if new_intrinsics is not None: new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device) new_depth = torch.nan_to_num(new_depth, nan=1e4) new_depth = torch.clamp(new_depth, min=0, max=1e4) if depth_alignment: target_depth, target_mask = self.render_cache( new_w2c.unsqueeze(1), new_intrinsics.unsqueeze(1), render_depth=True ) target_depth, target_mask = target_depth[:, :, 0], target_mask[:, :, 0] if alignment_method == "rigid": new_depth = ( align_depth( new_depth.squeeze(), target_depth.squeeze(), target_mask.bool().squeeze(), ) .reshape_as(new_depth) .detach() ) elif alignment_method == "non_rigid": with torch.enable_grad(): new_depth = ( align_depth( new_depth.squeeze(), target_depth.squeeze(), target_mask.bool().squeeze(), k=new_intrinsics.squeeze(), c2w=torch.inverse(new_w2c.squeeze()), alignment_method="non_rigid", num_iters=100, lambda_arap=0.1, smoothing_kernel_size=3, ) .reshape_as(new_depth) .detach() ) else: raise NotImplementedError new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu() new_image = new_image.cpu() if self.filter_points_threshold < 1.0: B, F, N, V, C, H, W = self.input_image.shape new_depth = new_depth.reshape(-1, 1, H, W) depth_mask = reliable_depth_mask_range_batch(new_depth, ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W) if new_mask is None: new_mask = depth_mask.to("cpu") else: new_mask = new_mask * depth_mask.to(new_mask.device) if new_mask is not None: new_mask = new_mask.cpu() if self.frame_buffer_max > 1: # newest frame first if self.input_image.shape[2] < self.frame_buffer_max: self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2) self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2) if self.input_mask is not None: self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2) else: self.input_image[:, :, 0] = new_image[:, None, None] self.input_points[:, :, 0] = new_points[:, None, None] if self.input_mask is not None: self.input_mask[:, :, 0] = new_mask[:, None, None] else: self.input_image = new_image[:, None, None, None] self.input_points = new_points[:, None, None, None] def render_cache( self, target_w2cs, target_intrinsics, render_depth: bool = False, start_frame_idx: int = 0, # For consistency with Cache4D ): assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3D_Buffer" output_device = target_w2cs.device target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device) target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device) pixels, masks = super().render_cache( target_w2cs, target_intrinsics, render_depth ) if not render_depth: noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype) per_buffer_noise = ( torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device) * self.noise_aug_strength ) pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1) # B, F, N, C, H, W return pixels.to(output_device), masks.to(output_device) class Cache4D(Cache3D_Base): def __init__(self, **kwargs): super().__init__(**kwargs) def update_cache(self, **kwargs): raise NotImplementedError def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0): rendered_warp_images, rendered_warp_masks = super().render_cache(target_w2cs, target_intrinsics, render_depth, start_frame_idx) return rendered_warp_images, rendered_warp_masks