# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import json
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Workaround to suppress MP warning
import sys
from io import BytesIO
import torch
from cosmos_transfer1.checkpoints import (
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH,
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH,
)
from cosmos_transfer1.diffusion.inference.inference_utils import (
default_model_names,
load_controlnet_specs,
valid_hint_keys,
)
from cosmos_transfer1.diffusion.inference.preprocessors import Preprocessors
from cosmos_transfer1.diffusion.inference.world_generation_pipeline import (
DiffusionControl2WorldMultiviewGenerationPipeline,
)
from cosmos_transfer1.utils import log, misc
from cosmos_transfer1.utils.io import save_video
torch.set_grad_enabled(False)  # inference-only script; torch.enable_grad(False) would silently not disable autograd
from cosmos_transfer1.checkpoints import (
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
BASE_7B_CHECKPOINT_PATH,
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH,
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH,
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH,
SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
)
from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl, VideoDiffusionT2VModelWithCtrl
from cosmos_transfer1.diffusion.model.model_multi_camera_ctrl import MultiVideoDiffusionModelWithCtrl
MODEL_CLASS_DICT = {
BASE_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl,
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl,
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl,
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl,
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl,
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl,
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl,
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: VideoDiffusionT2VModelWithCtrl,
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl,
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl,
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl,
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl,
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl,
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl,
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl,
SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl,
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl,
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl,
}
MODEL_NAME_DICT = {
BASE_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3",
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3",
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_vis_block3",
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3",
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_keypoint_block3",
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3",
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_upscale_block3",
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3",
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3",
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_lidar_block3",
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3",
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_block3",
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3",
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_lidar_block3",
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_waymo_block3",
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_waymo_block3",
}
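# Illustrative lookup: resolving the model class and config name for a checkpoint,
# e.g. the single-view-to-multiview t2w AV-sample base model:
#
#   MODEL_CLASS_DICT[BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH]
#   # -> MultiVideoDiffusionModelWithCtrl
#   MODEL_NAME_DICT[BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH]
#   # -> "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3"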
def parse_arguments() -> tuple[argparse.Namespace, dict]:
parser = argparse.ArgumentParser(description="Control to world generation demo script", conflict_handler="resolve")
parser.add_argument(
"--prompt",
type=str,
default="The video captures a stunning, photorealistic scene with remarkable attention to detail, giving it a lifelike appearance that is almost indistinguishable from reality. It appears to be from a high-budget 4K movie, showcasing ultra-high-definition quality with impeccable resolution.",
help="prompt which the sampled video condition on",
)
parser.add_argument(
"--prompt_left",
type=str,
default="The video is captured from a camera mounted on a car. The camera is facing to the left. ",
help="Text prompt for generating left camera view video",
)
parser.add_argument(
"--prompt_right",
type=str,
default="The video is captured from a camera mounted on a car. The camera is facing to the right.",
help="Text prompt for generating right camera view video",
)
parser.add_argument(
"--prompt_back",
type=str,
default="The video is captured from a camera mounted on a car. The camera is facing backwards.",
help="Text prompt for generating rear camera view video",
)
parser.add_argument(
"--prompt_back_left",
type=str,
default="The video is captured from a camera mounted on a car. The camera is facing the rear left side.",
help="Text prompt for generating left camera view video",
)
parser.add_argument(
"--prompt_back_right",
type=str,
default="The video is captured from a camera mounted on a car. The camera is facing the rear right side.",
help="Text prompt for generating right camera view video",
)
parser.add_argument(
"--view_condition_video",
type=str,
default="",
help="We require that only a single condition view is specified and this video is treated as conditioning for that view. "
"This video/videos should have the same duration as control videos",
)
parser.add_argument(
"--initial_condition_video",
type=str,
default="",
help="Can be either a path to a mp4 or a directory. If it is a mp4, we assume"
"that it is a video temporally concatenated with the same number of views as the model. "
"If it is a directory, we assume that the file names evaluate to integers that correspond to a view index,"
" e.g. '000.mp4', '003.mp4', '004.mp4'."
"This video/videos should have at least num_input_frames number of frames for each view. Frames will be taken from the back"
"of the video(s) if the duration of the video in each view exceed num_input_frames",
)
parser.add_argument(
"--num_input_frames",
type=int,
default=1,
help="Number of conditional frames for long video generation, not used in t2w",
choices=[1, 9],
)
parser.add_argument(
"--controlnet_specs",
type=str,
help="Path to JSON file specifying multicontrolnet configurations",
required=True,
)
parser.add_argument(
"--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
)
parser.add_argument(
"--tokenizer_dir",
type=str,
default="Cosmos-Tokenize1-CV8x8x8-720p",
help="Tokenizer weights directory relative to checkpoint_dir",
)
parser.add_argument(
"--video_save_name",
type=str,
default="output",
help="Output filename for generating a single video",
)
parser.add_argument(
"--video_save_folder",
type=str,
default="outputs/",
help="Output folder for generating a batch of videos",
)
parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps")
parser.add_argument("--guidance", type=float, default=5, help="Classifier-free guidance scale value")
parser.add_argument("--fps", type=int, default=24, help="FPS of the output video")
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument("--n_clip_max", type=int, default=-1, help="Maximum number of video extension loop")
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.")
parser.add_argument(
"--offload_diffusion_transformer",
action="store_true",
help="Offload DiT after inference",
)
parser.add_argument(
"--offload_text_encoder_model",
action="store_true",
help="Offload text encoder model after inference",
)
parser.add_argument(
"--offload_guardrail_models",
action="store_true",
help="Offload guardrail models after inference",
)
parser.add_argument(
"--upsample_prompt",
action="store_true",
help="Upsample prompt using Pixtral upsampler model",
)
parser.add_argument(
"--offload_prompt_upsampler",
action="store_true",
help="Offload prompt upsampler model after inference",
)
    parser.add_argument(
        "--waymo_example",
        # Note: argparse's type=bool treats any non-empty string (including "False")
        # as True, so parse the value explicitly.
        type=lambda s: str(s).lower() in ("true", "1", "yes"),
        default=False,
        help="Set to true when using the post-trained checkpoint from the Waymo post-training example",
    )
cmd_args = parser.parse_args()
# Load and parse JSON input
control_inputs, json_args = load_controlnet_specs(cmd_args)
control_inputs.update(json_args)
log.info(f"control_inputs: {json.dumps(control_inputs, indent=4)}")
log.info(f"args in json: {json.dumps(json_args, indent=4)}")
# if parameters not set on command line, use the ones from the controlnet_specs
# if both not set use command line defaults
for key in json_args:
if f"--{key}" not in sys.argv:
setattr(cmd_args, key, json_args[key])
log.info(f"final args: {json.dumps(vars(cmd_args), indent=4)}")
return cmd_args, control_inputs
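# A minimal, hypothetical example of the JSON accepted by --controlnet_specs
# (the real schema is defined by load_controlnet_specs and checked by
# validate_controlnet_specs below; top-level keys must come from valid_hint_keys,
# plus optional "prompts" / "view_condition_video" entries):
#
#   {
#       "hdmap": {"control_weight": 1.0, "input_control": "inputs/hdmap.mp4"},
#       "lidar": {"control_weight": 0.5, "input_control": "inputs/lidar.mp4"}
#   }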
def validate_controlnet_specs(cfg, controlnet_specs):
"""
Load and validate controlnet specifications from a JSON file.
Args:
json_path (str): Path to the JSON file containing controlnet specs.
checkpoint_dir (str): Base directory for checkpoint files.
Returns:
Dict[str, Any]: Validated and processed controlnet specifications.
"""
checkpoint_dir = cfg.checkpoint_dir
for hint_key, config in controlnet_specs.items():
        if hint_key not in list(valid_hint_keys) + ["prompts", "view_condition_video"]:
            raise ValueError(
                f"Invalid hint_key: {hint_key}. Must be one of {list(valid_hint_keys) + ['prompts', 'view_condition_video']}"
            )
if hint_key in valid_hint_keys:
if "ckpt_path" not in config:
log.info(f"No checkpoint path specified for {hint_key}. Using default.")
config["ckpt_path"] = os.path.join(checkpoint_dir, default_model_names[hint_key])
            # Regardless of whether "control_weight_prompt" is provided (i.e. whether we
            # automatically generate spatiotemporal control-weight binary masks), a
            # control_weight is still required.
if "control_weight" not in config:
log.warning(f"No control weight specified for {hint_key}. Setting to 0.5.")
config["control_weight"] = "0.5"
            else:
                # Check if control weight is a path or a scalar
                weight = config["control_weight"]
                if not isinstance(weight, str) or not weight.endswith(".pt"):
                    try:
                        # Try converting to float
                        scalar_value = float(weight)
                    except ValueError:
                        raise ValueError(
                            f"Control weight for {hint_key} must be a valid non-negative float or a path to a .pt file."
                        )
                    # Checked outside the try block so this message is not swallowed
                    # by the except clause above.
                    if scalar_value < 0:
                        raise ValueError(f"Control weight for {hint_key} must be non-negative.")
return controlnet_specs
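# Illustrative behavior (assuming default_model_names has an "hdmap" entry): a spec
# like {"hdmap": {}} leaves validate_controlnet_specs with a default ckpt_path under
# cfg.checkpoint_dir and control_weight "0.5".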
def demo(cfg, control_inputs):
"""Run control-to-world generation demo.
This function handles the main control-to-world generation pipeline, including:
- Setting up the random seed for reproducibility
- Initializing the generation pipeline with the provided configuration
- Processing single or multiple prompts/images/videos from input
- Generating videos from prompts and images/videos
- Saving the generated videos and corresponding prompts to disk
Args:
cfg (argparse.Namespace): Configuration namespace containing:
- Model configuration (checkpoint paths, model settings)
- Generation parameters (guidance, steps, dimensions)
- Input/output settings (prompts/images/videos, save paths)
- Performance options (model offloading settings)
The function will save:
- Generated MP4 video files
- Text files containing the processed prompts
    If guardrails block the generation, a critical log message is displayed
    and the function returns without saving any output.
"""
control_inputs = validate_controlnet_specs(cfg, control_inputs)
misc.set_random_seed(cfg.seed)
device_rank = 0
process_group = None
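    # With num_gpus > 1, inference is sharded via Megatron context parallelism,
    # splitting the sequence across the process group; only rank 0 writes outputs.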
if cfg.num_gpus > 1:
from megatron.core import parallel_state
from cosmos_transfer1.utils import distributed
distributed.init()
parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus)
process_group = parallel_state.get_context_parallel_group()
device_rank = distributed.get_rank(process_group)
preprocessors = Preprocessors()
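    # Select the SV2MV checkpoint: a v2w (video-to-world) variant when an initial
    # condition video is provided, otherwise a t2w (text-to-world) variant. The
    # Waymo post-trained models use five camera views (no rear view), hence the
    # shorter prompt list in that branch.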
if cfg.waymo_example:
prompts = [
cfg.prompt,
cfg.prompt_left,
cfg.prompt_right,
cfg.prompt_back_left,
cfg.prompt_back_right,
]
if cfg.initial_condition_video:
cfg.is_lvg_model = True
checkpoint = SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH
else:
cfg.is_lvg_model = False
cfg.num_input_frames = 0
checkpoint = SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH
else:
prompts = [
cfg.prompt,
cfg.prompt_left,
cfg.prompt_right,
cfg.prompt_back,
cfg.prompt_back_left,
cfg.prompt_back_right,
]
if cfg.initial_condition_video:
cfg.is_lvg_model = True
checkpoint = BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH
else:
cfg.is_lvg_model = False
cfg.num_input_frames = 0
checkpoint = BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH
# Initialize transfer generation model pipeline
pipeline = DiffusionControl2WorldMultiviewGenerationPipeline(
checkpoint_dir=cfg.checkpoint_dir,
checkpoint_name=checkpoint,
offload_network=cfg.offload_diffusion_transformer,
offload_text_encoder_model=cfg.offload_text_encoder_model,
offload_guardrail_models=cfg.offload_guardrail_models,
guidance=cfg.guidance,
num_steps=cfg.num_steps,
fps=cfg.fps,
seed=cfg.seed,
num_input_frames=cfg.num_input_frames,
control_inputs=control_inputs,
sigma_max=80.0,
num_video_frames=57,
process_group=process_group,
height=576,
width=1024,
is_lvg_model=cfg.is_lvg_model,
n_clip_max=cfg.n_clip_max,
waymo_example=cfg.waymo_example,
)
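    # The fixed settings above (57 frames, 576x1024) match the sv2mv checkpoints'
    # native clip length and resolution, as reflected in the MODEL_NAME_DICT entries.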
os.makedirs(cfg.video_save_folder, exist_ok=True)
current_prompt = prompts
current_video_path = ""
video_save_subfolder = os.path.join(cfg.video_save_folder, "video_0")
os.makedirs(video_save_subfolder, exist_ok=True)
current_control_inputs = copy.deepcopy(control_inputs)
# if control inputs are not provided, run respective preprocessor (for seg and depth)
preprocessors(current_video_path, current_prompt, current_control_inputs, video_save_subfolder)
# Generate video
generated_output = pipeline.generate(
prompts=current_prompt,
view_condition_video=cfg.view_condition_video,
initial_condition_video=cfg.initial_condition_video,
control_inputs=current_control_inputs,
save_folder=video_save_subfolder,
)
    if generated_output is None:
        log.critical("Guardrail blocked generation.")
        return
    video, prompt = generated_output
video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")
if device_rank == 0:
# Save video
os.makedirs(os.path.dirname(video_save_path), exist_ok=True)
save_video(
video=video,
fps=cfg.fps,
H=video.shape[1],
W=video.shape[2],
video_save_quality=7,
video_save_path=video_save_path,
)
# Save prompt to text file alongside video
with open(prompt_save_path, "wb") as f:
f.write(";".join(prompt).encode("utf-8"))
log.info(f"Saved video to {video_save_path}")
log.info(f"Saved prompt to {prompt_save_path}")
# clean up properly
if cfg.num_gpus > 1:
parallel_state.destroy_model_parallel()
import torch.distributed as dist
dist.destroy_process_group()
if __name__ == "__main__":
args, control_inputs = parse_arguments()
demo(args, control_inputs)