# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os import cv2 from moge.model.v1 import MoGeModel import torch import numpy as np from cosmos_predict1.diffusion.inference.inference_utils import ( add_common_arguments, check_input_frames, validate_args, ) from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline from cosmos_predict1.utils import log, misc from cosmos_predict1.utils.io import read_prompts_from_file, save_video from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer from cosmos_predict1.diffusion.inference.camera_utils import generate_camera_trajectory import torch.nn.functional as F torch.enable_grad(False) def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Video to world generation demo script") # Add common arguments add_common_arguments(parser) parser.add_argument( "--prompt_upsampler_dir", type=str, default="Pixtral-12B", help="Prompt upsampler weights directory relative to checkpoint_dir", ) # TODO: do we need this? parser.add_argument( "--input_image_path", type=str, help="Input image path for generating a single video", ) parser.add_argument( "--trajectory", type=str, choices=[ "left", "right", "up", "down", "zoom_in", "zoom_out", "clockwise", "counterclockwise", "none", ], default="left", help="Select a trajectory type from the available options (default: original)", ) parser.add_argument( "--camera_rotation", type=str, choices=["center_facing", "no_rotation", "trajectory_aligned"], default="center_facing", help="Controls camera rotation during movement: center_facing (rotate to look at center), no_rotation (keep orientation), or trajectory_aligned (rotate in the direction of movement)", ) parser.add_argument( "--movement_distance", type=float, default=0.3, help="Distance of the camera from the center of the scene", ) parser.add_argument( "--noise_aug_strength", type=float, default=0.0, help="Strength of noise augmentation on warped frames", ) parser.add_argument( "--save_buffer", action="store_true", help="If set, save the warped images (buffer) side by side with the output video.", ) parser.add_argument( "--filter_points_threshold", type=float, default=0.05, help="If set, filter the points continuity of the warped images.", ) parser.add_argument( "--foreground_masking", action="store_true", help="If set, use foreground masking for the warped images.", ) return parser def parse_arguments() -> argparse.Namespace: parser = create_parser() return parser.parse_args() def validate_args(args): assert args.num_video_frames is not None, "num_video_frames must be provided" assert (args.num_video_frames - 1) % 120 == 0, "num_video_frames must be 121, 241, 361, ... (N*120+1)" def _predict_moge_depth(current_image_path: str | np.ndarray, target_h: int, target_w: int, device: torch.device, moge_model: MoGeModel): """Handles MoGe depth prediction for a single image. If the image is directly provided as a NumPy array, it should have shape [H, W, C], where the channels are RGB and the pixel values are in [0..255]. """ if isinstance(current_image_path, str): input_image_bgr = cv2.imread(current_image_path) if input_image_bgr is None: raise FileNotFoundError(f"Input image not found: {current_image_path}") input_image_rgb = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB) else: input_image_rgb = current_image_path del current_image_path depth_pred_h, depth_pred_w = 720, 1280 input_image_for_depth_resized = cv2.resize(input_image_rgb, (depth_pred_w, depth_pred_h)) input_image_for_depth_tensor_chw = torch.tensor(input_image_for_depth_resized / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1) moge_output_full = moge_model.infer(input_image_for_depth_tensor_chw) moge_depth_hw_full = moge_output_full["depth"] moge_intrinsics_33_full_normalized = moge_output_full["intrinsics"] moge_mask_hw_full = moge_output_full["mask"] moge_depth_hw_full = torch.where(moge_mask_hw_full==0, torch.tensor(1000.0, device=moge_depth_hw_full.device), moge_depth_hw_full) moge_intrinsics_33_full_pixel = moge_intrinsics_33_full_normalized.clone() moge_intrinsics_33_full_pixel[0, 0] *= depth_pred_w moge_intrinsics_33_full_pixel[1, 1] *= depth_pred_h moge_intrinsics_33_full_pixel[0, 2] *= depth_pred_w moge_intrinsics_33_full_pixel[1, 2] *= depth_pred_h # Calculate scaling factor for height height_scale_factor = target_h / depth_pred_h width_scale_factor = target_w / depth_pred_w # Resize depth map, mask, and image tensor # Resizing depth: (H, W) -> (1, 1, H, W) for interpolate, then squeeze moge_depth_hw = F.interpolate( moge_depth_hw_full.unsqueeze(0).unsqueeze(0), size=(target_h, target_w), mode='bilinear', align_corners=False ).squeeze(0).squeeze(0) # Resizing mask: (H, W) -> (1, 1, H, W) for interpolate, then squeeze moge_mask_hw = F.interpolate( moge_mask_hw_full.unsqueeze(0).unsqueeze(0).to(torch.float32), size=(target_h, target_w), mode='nearest', # Using nearest neighbor for binary mask ).squeeze(0).squeeze(0).to(torch.bool) # Resizing image tensor: (C, H, W) -> (1, C, H, W) for interpolate, then squeeze input_image_tensor_chw_target_res = F.interpolate( input_image_for_depth_tensor_chw.unsqueeze(0), size=(target_h, target_w), mode='bilinear', align_corners=False ).squeeze(0) moge_image_b1chw_float = input_image_tensor_chw_target_res.unsqueeze(0).unsqueeze(1) * 2 - 1 moge_intrinsics_33 = moge_intrinsics_33_full_pixel.clone() # Adjust intrinsics for resized height moge_intrinsics_33[1, 1] *= height_scale_factor # fy moge_intrinsics_33[1, 2] *= height_scale_factor # cy moge_intrinsics_33[0, 0] *= width_scale_factor # fx moge_intrinsics_33[0, 2] *= width_scale_factor # cx moge_depth_b11hw = moge_depth_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0) moge_depth_b11hw = torch.nan_to_num(moge_depth_b11hw, nan=1e4) moge_depth_b11hw = torch.clamp(moge_depth_b11hw, min=0, max=1e4) moge_mask_b11hw = moge_mask_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0) # Prepare initial intrinsics [B, 1, 3, 3] moge_intrinsics_b133 = moge_intrinsics_33.unsqueeze(0).unsqueeze(0) initial_w2c_44 = torch.eye(4, dtype=torch.float32, device=device) moge_initial_w2c_b144 = initial_w2c_44.unsqueeze(0).unsqueeze(0) return ( moge_image_b1chw_float, moge_depth_b11hw, moge_mask_b11hw, moge_initial_w2c_b144, moge_intrinsics_b133, ) def _predict_moge_depth_from_tensor( image_tensor_chw_0_1: torch.Tensor, # Shape (C, H_input, W_input), range [0,1] moge_model: MoGeModel ): """Handles MoGe depth prediction from an image tensor.""" moge_output_full = moge_model.infer(image_tensor_chw_0_1) moge_depth_hw_full = moge_output_full["depth"] # (moge_inf_h, moge_inf_w) moge_mask_hw_full = moge_output_full["mask"] # (moge_inf_h, moge_inf_w) moge_depth_11hw = moge_depth_hw_full.unsqueeze(0).unsqueeze(0) moge_depth_11hw = torch.nan_to_num(moge_depth_11hw, nan=1e4) moge_depth_11hw = torch.clamp(moge_depth_11hw, min=0, max=1e4) moge_mask_11hw = moge_mask_hw_full.unsqueeze(0).unsqueeze(0) moge_depth_11hw = torch.where(moge_mask_11hw==0, torch.tensor(1000.0, device=moge_depth_11hw.device), moge_depth_11hw) return moge_depth_11hw, moge_mask_11hw def demo(args): """Run video-to-world generation demo. This function handles the main video-to-world generation pipeline, including: - Setting up the random seed for reproducibility - Initializing the generation pipeline with the provided configuration - Processing single or multiple prompts/images/videos from input - Generating videos from prompts and images/videos - Saving the generated videos and corresponding prompts to disk Args: cfg (argparse.Namespace): Configuration namespace containing: - Model configuration (checkpoint paths, model settings) - Generation parameters (guidance, steps, dimensions) - Input/output settings (prompts/images/videos, save paths) - Performance options (model offloading settings) The function will save: - Generated MP4 video files - Text files containing the processed prompts If guardrails block the generation, a critical log message is displayed and the function continues to the next prompt if available. """ misc.set_random_seed(args.seed) inference_type = "video2world" validate_args(args) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if args.num_gpus > 1: from megatron.core import parallel_state from cosmos_predict1.utils import distributed distributed.init() parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) process_group = parallel_state.get_context_parallel_group() # Initialize video2world generation model pipeline pipeline = Gen3cPipeline( inference_type=inference_type, checkpoint_dir=args.checkpoint_dir, checkpoint_name="Gen3C-Cosmos-7B", prompt_upsampler_dir=args.prompt_upsampler_dir, enable_prompt_upsampler=not args.disable_prompt_upsampler, offload_network=args.offload_diffusion_transformer, offload_tokenizer=args.offload_tokenizer, offload_text_encoder_model=args.offload_text_encoder_model, offload_prompt_upsampler=args.offload_prompt_upsampler, offload_guardrail_models=args.offload_guardrail_models, disable_guardrail=args.disable_guardrail, guidance=args.guidance, num_steps=args.num_steps, height=args.height, width=args.width, fps=args.fps, num_video_frames=121, seed=args.seed, ) frame_buffer_max = pipeline.model.frame_buffer_max generator = torch.Generator(device=device).manual_seed(args.seed) sample_n_frames = pipeline.model.chunk_size moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) if args.num_gpus > 1: pipeline.model.net.enable_context_parallel(process_group) # Handle multiple prompts if prompt file is provided if args.batch_input_path: log.info(f"Reading batch inputs from path: {args.batch_input_path}") prompts = read_prompts_from_file(args.batch_input_path) else: # Single prompt case prompts = [{"prompt": args.prompt, "visual_input": args.input_image_path}] os.makedirs(os.path.dirname(args.video_save_folder), exist_ok=True) for i, input_dict in enumerate(prompts): current_prompt = input_dict.get("prompt", None) if current_prompt is None and args.disable_prompt_upsampler: log.critical("Prompt is missing, skipping world generation.") continue current_image_path = input_dict.get("visual_input", None) if current_image_path is None: log.critical("Visual input is missing, skipping world generation.") continue # Check input frames if not check_input_frames(current_image_path, 1): print(f"Input image {current_image_path} is not valid, skipping.") continue # load image, predict depth and initialize 3D cache ( moge_image_b1chw_float, moge_depth_b11hw, moge_mask_b11hw, moge_initial_w2c_b144, moge_intrinsics_b133, ) = _predict_moge_depth( current_image_path, args.height, args.width, device, moge_model ) cache = Cache3D_Buffer( frame_buffer_max=frame_buffer_max, generator=generator, noise_aug_strength=args.noise_aug_strength, input_image=moge_image_b1chw_float[:, 0].clone(), # [B, C, H, W] input_depth=moge_depth_b11hw[:, 0], # [B, 1, H, W] # input_mask=moge_mask_b11hw[:, 0], # [B, 1, H, W] input_w2c=moge_initial_w2c_b144[:, 0], # [B, 4, 4] input_intrinsics=moge_intrinsics_b133[:, 0],# [B, 3, 3] filter_points_threshold=args.filter_points_threshold, foreground_masking=args.foreground_masking, ) initial_cam_w2c_for_traj = moge_initial_w2c_b144[0, 0] initial_cam_intrinsics_for_traj = moge_intrinsics_b133[0, 0] # Generate camera trajectory using the new utility function try: generated_w2cs, generated_intrinsics = generate_camera_trajectory( trajectory_type=args.trajectory, initial_w2c=initial_cam_w2c_for_traj, initial_intrinsics=initial_cam_intrinsics_for_traj, num_frames=args.num_video_frames, movement_distance=args.movement_distance, camera_rotation=args.camera_rotation, center_depth=1.0, device=device.type, ) except (ValueError, NotImplementedError) as e: log.critical(f"Failed to generate trajectory: {e}") continue log.info(f"Generating 0 - {sample_n_frames} frames") rendered_warp_images, rendered_warp_masks = cache.render_cache( generated_w2cs[:, 0:sample_n_frames], generated_intrinsics[:, 0:sample_n_frames], ) all_rendered_warps = [] if args.save_buffer: all_rendered_warps.append(rendered_warp_images.clone().cpu()) # Generate video generated_output = pipeline.generate( prompt=current_prompt, image_path=current_image_path, negative_prompt=args.negative_prompt, rendered_warp_images=rendered_warp_images, rendered_warp_masks=rendered_warp_masks, ) if generated_output is None: log.critical("Guardrail blocked video2world generation.") continue video, prompt = generated_output num_ar_iterations = (generated_w2cs.shape[1] - 1) // (sample_n_frames - 1) for num_iter in range(1, num_ar_iterations): start_frame_idx = num_iter * (sample_n_frames - 1) # Overlap by 1 frame end_frame_idx = start_frame_idx + sample_n_frames log.info(f"Generating {start_frame_idx} - {end_frame_idx} frames") last_frame_hwc_0_255 = torch.tensor(video[-1], device=device) pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] pred_depth, pred_mask = _predict_moge_depth_from_tensor( pred_image_for_depth_chw_0_1, moge_model ) cache.update_cache( new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1, # (B,C,H,W) range [-1,1] new_depth=pred_depth, # (1,1,H,W) # new_mask=pred_mask, # (1,1,H,W) new_w2c=generated_w2cs[:, start_frame_idx], new_intrinsics=generated_intrinsics[:, start_frame_idx], ) current_segment_w2cs = generated_w2cs[:, start_frame_idx:end_frame_idx] current_segment_intrinsics = generated_intrinsics[:, start_frame_idx:end_frame_idx] rendered_warp_images, rendered_warp_masks = cache.render_cache( current_segment_w2cs, current_segment_intrinsics, ) if args.save_buffer: all_rendered_warps.append(rendered_warp_images[:, 1:].clone().cpu()) pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] generated_output = pipeline.generate( prompt=current_prompt, image_path=pred_image_for_depth_bcthw_minus1_1, negative_prompt=args.negative_prompt, rendered_warp_images=rendered_warp_images, rendered_warp_masks=rendered_warp_masks, ) video_new, prompt = generated_output video = np.concatenate([video, video_new[1:]], axis=0) # Final video processing final_video_to_save = video final_width = args.width if args.save_buffer and all_rendered_warps: squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is (T_chunk, n_i, C, H, W) if squeezed_warps: n_max = max(t.shape[1] for t in squeezed_warps) padded_t_list = [] for sq_t in squeezed_warps: # sq_t shape: (T_chunk, n_i, C, H, W) current_n_i = sq_t.shape[1] padding_needed_dim1 = n_max - current_n_i pad_spec = (0,0, # W 0,0, # H 0,0, # C 0,padding_needed_dim1, # n_i 0,0) # T_chunk padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0) padded_t_list.append(padded_t) full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0) T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4) buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim) buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0 buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8) buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1)) final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2) final_width = args.width * (1 + n_max) log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}") else: log.info("No warp buffers to save.") video_save_path = os.path.join( args.video_save_folder, f"{i if args.batch_input_path else args.video_save_name}.mp4" ) os.makedirs(os.path.dirname(video_save_path), exist_ok=True) # Save video save_video( video=final_video_to_save, fps=args.fps, H=args.height, W=final_width, video_save_quality=5, video_save_path=video_save_path, ) log.info(f"Saved video to {video_save_path}") # clean up properly if args.num_gpus > 1: parallel_state.destroy_model_parallel() import torch.distributed as dist dist.destroy_process_group() if __name__ == "__main__": args = parse_arguments() if args.prompt is None: args.prompt = "" args.disable_guardrail = True args.disable_prompt_upsampler = True demo(args)