|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import copy |
|
import json |
|
import os |
|
|
|
# Set before the project/third-party imports below so the value is already in the
# environment when they load. Presumably this is read by the HuggingFace `tokenizers`
# library to turn off its internal thread pool (commonly done to avoid fork-related
# warnings/deadlocks) — TODO confirm which dependency relies on it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
import sys |
|
from io import BytesIO |
|
|
|
import torch |
|
|
|
from cosmos_transfer1.checkpoints import ( |
|
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, |
|
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, |
|
) |
|
from cosmos_transfer1.diffusion.inference.inference_utils import ( |
|
default_model_names, |
|
load_controlnet_specs, |
|
valid_hint_keys, |
|
) |
|
from cosmos_transfer1.diffusion.inference.preprocessors import Preprocessors |
|
from cosmos_transfer1.diffusion.inference.world_generation_pipeline import ( |
|
DiffusionControl2WorldMultiviewGenerationPipeline, |
|
) |
|
from cosmos_transfer1.utils import log, misc |
|
from cosmos_transfer1.utils.io import save_video |
|
|
|
# Disable autograd globally: this script only runs inference. The previous spelling,
# `torch.enable_grad(False)`, did not do this. `torch.enable_grad` is a no-argument
# context manager/decorator, so that call merely wrapped `False` in a decorator and
# left grad mode unchanged. `torch.set_grad_enabled(mode)` sets the mode when called
# as a function.
torch.set_grad_enabled(False)
|
|
|
from cosmos_transfer1.checkpoints import ( |
|
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH, |
|
BASE_7B_CHECKPOINT_PATH, |
|
DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, |
|
VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, |
|
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, |
|
SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, |
|
SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, |
|
) |
|
from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl, VideoDiffusionT2VModelWithCtrl |
|
from cosmos_transfer1.diffusion.model.model_multi_camera_ctrl import MultiVideoDiffusionModelWithCtrl |
|
|
|
# Checkpoint path -> model class used to instantiate it, grouped by class.
# Insertion order within and across the groups is significant only for iteration.
MODEL_CLASS_DICT = {
    **dict.fromkeys(
        (
            BASE_7B_CHECKPOINT_PATH,
            EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH,
        ),
        VideoDiffusionModelWithCtrl,
    ),
    **dict.fromkeys(
        (
            BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
            HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
        ),
        VideoDiffusionT2VModelWithCtrl,
    ),
    # Multi-view checkpoints (both the t2w and v2w variants, incl. the Waymo ones).
    **dict.fromkeys(
        (
            BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH,
            SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH,
            SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH,
            SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH,
            SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH,
        ),
        MultiVideoDiffusionModelWithCtrl,
    ),
}
|
|
|
# Checkpoint path -> model/experiment name string (presumably the name of the config
# used to build the network — confirm against the pipeline that consumes this).
# Every checkpoint registered in MODEL_CLASS_DICT has an entry here, so a lookup by
# checkpoint path does not raise KeyError. This includes the two SV2MV v2w control
# checkpoints, which were previously missing.
MODEL_NAME_DICT = {
    BASE_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3",
    EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3",
    VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_vis_block3",
    DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3",
    KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_keypoint_block3",
    SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3",
    UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_upscale_block3",
    BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3",
    HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3",
    LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_lidar_block3",
    BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3",
    BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_block3",
    SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3",
    SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_lidar_block3",
    # The v2w control checkpoints follow the same naming as their t2w counterparts
    # (cf. BASE_v2w above, which shares the v2w hdmap name).
    SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_block3",
    SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_lidar_block3",
    SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_waymo_block3",
    SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_waymo_block3",
}
|
|
|
|
|
def _str2bool(value) -> bool:
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(string)``. That is ``True`` for any
    non-empty string, so ``--flag False`` would silently enable the flag. This
    parser accepts the usual true/false spellings instead.

    Args:
        value: Raw option value (a ``str`` from the command line, or already a ``bool``).

    Returns:
        bool: The parsed value. An empty string maps to ``False``, as ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (true/false), got {value!r}")


def _flag_given_on_command_line(key, argv) -> bool:
    """Return True if ``--<key>`` appears in ``argv`` as ``--key value`` or ``--key=value``."""
    flag = f"--{key}"
    return any(token == flag or token.startswith(f"{flag}=") for token in argv)


def _build_parser() -> argparse.ArgumentParser:
    """Build this script's argument parser (construction only; nothing is parsed here)."""
    parser = argparse.ArgumentParser(description="Control to world generation demo script", conflict_handler="resolve")
    parser.add_argument(
        "--prompt",
        type=str,
        default="The video captures a stunning, photorealistic scene with remarkable attention to detail, giving it a lifelike appearance that is almost indistinguishable from reality. It appears to be from a high-budget 4K movie, showcasing ultra-high-definition quality with impeccable resolution.",
        help="Text prompt that the sampled video is conditioned on",
    )
    parser.add_argument(
        "--prompt_left",
        type=str,
        default="The video is captured from a camera mounted on a car. The camera is facing to the left. ",
        help="Text prompt for generating left camera view video",
    )
    parser.add_argument(
        "--prompt_right",
        type=str,
        default="The video is captured from a camera mounted on a car. The camera is facing to the right.",
        help="Text prompt for generating right camera view video",
    )
    parser.add_argument(
        "--prompt_back",
        type=str,
        default="The video is captured from a camera mounted on a car. The camera is facing backwards.",
        help="Text prompt for generating rear camera view video",
    )
    parser.add_argument(
        "--prompt_back_left",
        type=str,
        default="The video is captured from a camera mounted on a car. The camera is facing the rear left side.",
        help="Text prompt for generating rear-left camera view video",
    )
    parser.add_argument(
        "--prompt_back_right",
        type=str,
        default="The video is captured from a camera mounted on a car. The camera is facing the rear right side.",
        help="Text prompt for generating rear-right camera view video",
    )
    parser.add_argument(
        "--view_condition_video",
        type=str,
        default="",
        help="We require that only a single condition view is specified and this video is treated as conditioning for that view. "
        "This video/videos should have the same duration as control videos",
    )
    parser.add_argument(
        "--initial_condition_video",
        type=str,
        default="",
        help="Can be either a path to a mp4 or a directory. If it is a mp4, we assume "
        "that it is a video temporally concatenated with the same number of views as the model. "
        "If it is a directory, we assume that the file names evaluate to integers that correspond to a view index,"
        " e.g. '000.mp4', '003.mp4', '004.mp4'. "
        "This video/videos should have at least num_input_frames number of frames for each view. Frames will be taken from the back "
        "of the video(s) if the duration of the video in each view exceeds num_input_frames",
    )
    parser.add_argument(
        "--num_input_frames",
        type=int,
        default=1,
        help="Number of conditional frames for long video generation, not used in t2w",
        choices=[1, 9],
    )
    parser.add_argument(
        "--controlnet_specs",
        type=str,
        help="Path to JSON file specifying multicontrolnet configurations",
        required=True,
    )
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--tokenizer_dir",
        type=str,
        default="Cosmos-Tokenize1-CV8x8x8-720p",
        help="Tokenizer weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--video_save_name",
        type=str,
        default="output",
        help="Output filename for generating a single video",
    )
    parser.add_argument(
        "--video_save_folder",
        type=str,
        default="outputs/",
        help="Output folder for generating a batch of videos",
    )
    parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps")
    parser.add_argument("--guidance", type=float, default=5, help="Classifier-free guidance scale value")
    parser.add_argument("--fps", type=int, default=24, help="FPS of the output video")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--n_clip_max", type=int, default=-1, help="Maximum number of video extension loop")
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.")
    parser.add_argument(
        "--offload_diffusion_transformer",
        action="store_true",
        help="Offload DiT after inference",
    )
    parser.add_argument(
        "--offload_text_encoder_model",
        action="store_true",
        help="Offload text encoder model after inference",
    )
    parser.add_argument(
        "--offload_guardrail_models",
        action="store_true",
        help="Offload guardrail models after inference",
    )
    parser.add_argument(
        "--upsample_prompt",
        action="store_true",
        help="Upsample prompt using Pixtral upsampler model",
    )
    parser.add_argument(
        "--offload_prompt_upsampler",
        action="store_true",
        help="Offload prompt upsampler model after inference",
    )
    # `type=bool` would turn "--waymo_example False" into True, so parse explicitly.
    # "--waymo_example true" keeps working; a bare "--waymo_example" also means True.
    parser.add_argument(
        "--waymo_example",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Set to true when using post-trained checkpoint from the Waymo post-training example "
        "(accepts true/false; a bare --waymo_example means true)",
    )
    return parser


def parse_arguments() -> "tuple[argparse.Namespace, dict]":
    """Parse the command line and merge in the controlnet spec JSON.

    Returns:
        tuple: ``(cmd_args, control_inputs)``.
            - ``cmd_args`` is the parsed namespace. Each key in the JSON's extra
              arguments overrides the argparse default, unless that option was given
              explicitly on the command line.
            - ``control_inputs`` is the controlnet spec dict from
              ``load_controlnet_specs``, updated with those same JSON arguments.
    """
    cmd_args = _build_parser().parse_args()

    # Presumably reads the JSON file named by --controlnet_specs. It returns the
    # control configs plus a dict of extra arguments, which are merged into
    # control_inputs here.
    control_inputs, json_args = load_controlnet_specs(cmd_args)
    control_inputs.update(json_args)
    log.info(f"control_inputs: {json.dumps(control_inputs, indent=4)}")
    log.info(f"args in json: {json.dumps(json_args, indent=4)}")

    # JSON values act as defaults: they replace argparse defaults, but an option passed
    # explicitly on the command line (as "--key value" or "--key=value") wins.
    for key, value in json_args.items():
        if not _flag_given_on_command_line(key, sys.argv[1:]):
            setattr(cmd_args, key, value)

    log.info(f"final args: {json.dumps(vars(cmd_args), indent=4)}")
    return cmd_args, control_inputs
|
|
|
|
|
def validate_controlnet_specs(cfg, controlnet_specs):
    """Validate controlnet specs in place and fill in missing defaults.

    Args:
        cfg (argparse.Namespace): Run configuration. Only ``cfg.checkpoint_dir`` is read;
            it is used to build default checkpoint paths.
        controlnet_specs (dict): Mapping from a control hint key (one of
            ``valid_hint_keys``) or a special key (``"prompts"``,
            ``"view_condition_video"``) to its value. Values for hint keys are dicts
            that are modified in place.

    Returns:
        dict: The same ``controlnet_specs`` object. Every hint-key entry is guaranteed
        to have ``ckpt_path`` and ``control_weight``.

    Raises:
        ValueError: If a key is not recognised, or if a ``control_weight`` is neither a
            non-negative number nor a string path ending in ``.pt``.
    """
    checkpoint_dir = cfg.checkpoint_dir
    # Non-control keys that may legitimately appear alongside the control specs.
    allowed_keys = list(valid_hint_keys) + ["prompts", "view_condition_video"]

    for hint_key, config in controlnet_specs.items():
        if hint_key not in allowed_keys:
            raise ValueError(f"Invalid hint_key: {hint_key}. Must be one of {allowed_keys}")
        if hint_key not in valid_hint_keys:
            # Special keys carry no checkpoint or weight to validate.
            continue

        if "ckpt_path" not in config:
            log.info(f"No checkpoint path specified for {hint_key}. Using default.")
            config["ckpt_path"] = os.path.join(checkpoint_dir, default_model_names[hint_key])

        if "control_weight" not in config:
            log.warning(f"No control weight specified for {hint_key}. Setting to 0.5.")
            config["control_weight"] = "0.5"
            continue

        weight = config["control_weight"]
        if isinstance(weight, str) and weight.endswith(".pt"):
            # A path to a weight file; its contents are not checked here.
            continue

        # Convert first, then range-check outside the try, so the "non-negative"
        # error is not swallowed and replaced by the generic conversion error.
        try:
            scalar_value = float(weight)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Control weight for {hint_key} must be a valid non-negative float or a path to a .pt file."
            ) from err
        # `not >= 0` also rejects NaN, which a plain `< 0` test would let through.
        if not scalar_value >= 0:
            raise ValueError(f"Control weight for {hint_key} must be non-negative.")

    return controlnet_specs
|
|
|
|
|
def _select_prompts_and_checkpoint(cfg):
    """Choose the per-view prompts and the checkpoint for this run.

    Mutates ``cfg``: sets ``cfg.is_lvg_model``. When no ``initial_condition_video``
    is given, it also forces ``cfg.num_input_frames`` to 0.

    Args:
        cfg (argparse.Namespace): Reads ``waymo_example``, ``initial_condition_video``
            and the ``prompt*`` fields.

    Returns:
        tuple: ``(prompts, checkpoint)``, where ``prompts`` is the list of per-view
        prompts in view order and ``checkpoint`` is the checkpoint path to load.
    """
    if cfg.waymo_example:
        # The Waymo example uses five prompts: there is no separate prompt_back view.
        prompts = [
            cfg.prompt,
            cfg.prompt_left,
            cfg.prompt_right,
            cfg.prompt_back_left,
            cfg.prompt_back_right,
        ]
        v2w_checkpoint = SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH
        t2w_checkpoint = SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH
    else:
        prompts = [
            cfg.prompt,
            cfg.prompt_left,
            cfg.prompt_right,
            cfg.prompt_back,
            cfg.prompt_back_left,
            cfg.prompt_back_right,
        ]
        v2w_checkpoint = BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH
        t2w_checkpoint = BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH

    # An initial condition video selects the v2w checkpoint and marks the run as an
    # LVG model. Otherwise the t2w checkpoint is used with no conditioning frames.
    if cfg.initial_condition_video:
        cfg.is_lvg_model = True
        checkpoint = v2w_checkpoint
    else:
        cfg.is_lvg_model = False
        cfg.num_input_frames = 0
        checkpoint = t2w_checkpoint
    return prompts, checkpoint


def _save_outputs(cfg, video, prompt):
    """Write the generated video and its prompt(s) into ``cfg.video_save_folder``.

    Args:
        cfg (argparse.Namespace): Reads ``video_save_folder``, ``video_save_name`` and ``fps``.
        video: Generated frames. ``shape[1]``/``shape[2]`` are passed to ``save_video``
            as height/width (presumably a (T, H, W, C) array — confirm against the pipeline).
        prompt: Iterable of prompt strings, written ``;``-joined as UTF-8.
    """
    video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
    prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

    # Create the target's parent in case video_save_name contains a subdirectory.
    os.makedirs(os.path.dirname(video_save_path), exist_ok=True)
    save_video(
        video=video,
        fps=cfg.fps,
        H=video.shape[1],
        W=video.shape[2],
        video_save_quality=7,
        video_save_path=video_save_path,
    )

    with open(prompt_save_path, "wb") as f:
        f.write(";".join(prompt).encode("utf-8"))

    log.info(f"Saved video to {video_save_path}")
    log.info(f"Saved prompt to {prompt_save_path}")


def demo(cfg, control_inputs):
    """Run one multi-view control-to-world generation and save the result.

    Steps:
        - Validate ``control_inputs`` (filling defaults) and set the random seed.
        - When ``cfg.num_gpus > 1``, initialise context parallelism across that many GPUs.
        - Pick the per-view prompts and the checkpoint.
        - Build the multi-view pipeline, run the preprocessors, and generate.
        - On rank 0, write ``<video_save_name>.mp4`` and ``<video_save_name>.txt``
          into ``cfg.video_save_folder``.

    Args:
        cfg (argparse.Namespace): Configuration: checkpoint dir, sampling parameters
            (guidance, steps, fps, seed), prompts, conditioning videos, save paths, and
            offloading options.
        control_inputs (dict): Controlnet specs as returned by ``parse_arguments``.

    If the guardrails block the generation (``pipeline.generate`` returns ``None``), a
    critical message is logged and nothing is saved. In that case the distributed
    state is still torn down when ``cfg.num_gpus > 1``.
    """
    control_inputs = validate_controlnet_specs(cfg, control_inputs)
    misc.set_random_seed(cfg.seed)

    device_rank = 0
    process_group = None
    if cfg.num_gpus > 1:
        from megatron.core import parallel_state

        from cosmos_transfer1.utils import distributed

        distributed.init()
        parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus)
        process_group = parallel_state.get_context_parallel_group()
        device_rank = distributed.get_rank(process_group)

    preprocessors = Preprocessors()
    prompts, checkpoint = _select_prompts_and_checkpoint(cfg)

    pipeline = DiffusionControl2WorldMultiviewGenerationPipeline(
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=checkpoint,
        offload_network=cfg.offload_diffusion_transformer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        fps=cfg.fps,
        seed=cfg.seed,
        num_input_frames=cfg.num_input_frames,
        control_inputs=control_inputs,
        sigma_max=80.0,
        num_video_frames=57,
        process_group=process_group,
        height=576,
        width=1024,
        is_lvg_model=cfg.is_lvg_model,
        n_clip_max=cfg.n_clip_max,
        waymo_example=cfg.waymo_example,
    )

    os.makedirs(cfg.video_save_folder, exist_ok=True)
    video_save_subfolder = os.path.join(cfg.video_save_folder, "video_0")
    os.makedirs(video_save_subfolder, exist_ok=True)

    # Hand the preprocessors a private copy; presumably they may modify the dict
    # they receive, and the validated control_inputs must stay untouched.
    current_control_inputs = copy.deepcopy(control_inputs)
    input_video_path = ""  # no single input video path is passed to the preprocessors
    preprocessors(input_video_path, prompts, current_control_inputs, video_save_subfolder)

    generated_output = pipeline.generate(
        prompts=prompts,
        view_condition_video=cfg.view_condition_video,
        initial_condition_video=cfg.initial_condition_video,
        control_inputs=current_control_inputs,
        save_folder=video_save_subfolder,
    )
    if generated_output is None:
        # Previously execution fell through and unpacked None (TypeError), which also
        # skipped the distributed teardown below.
        log.critical("Guardrail blocked generation.")
    else:
        video, prompt = generated_output
        if device_rank == 0:
            _save_outputs(cfg, video, prompt)

    if cfg.num_gpus > 1:
        parallel_state.destroy_model_parallel()
        import torch.distributed as dist

        dist.destroy_process_group()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI (merged with the controlnet spec JSON) and
    # hand the resulting (cfg, control_inputs) pair straight to demo().
    demo(*parse_arguments())
|
|