cosmos_transfer1_av / cosmos_transfer1 /diffusion /config /inference /cosmos-1-diffusion-control2world.py
harry900000's picture
add cosmos-tranfer1/ into repo
226c7c9
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hydra.core.config_store import ConfigStore
from cosmos_transfer1.checkpoints import (
BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH,
BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH,
BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH,
)
from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB
from cosmos_transfer1.diffusion.model.model_ctrl import (
VideoDiffusionModelWithCtrl,
VideoDiffusionT2VModelWithCtrl,
VideoDistillModelWithCtrl,
)
from cosmos_transfer1.diffusion.model.model_multi_camera_ctrl import MultiVideoDiffusionModelWithCtrl
from cosmos_transfer1.diffusion.networks.general_dit_multi_view import MultiViewVideoExtendGeneralDIT
from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT
from cosmos_transfer1.utils.lazy_config import LazyCall as L
from cosmos_transfer1.utils.lazy_config import LazyDict
cs = ConfigStore.instance()
# Base configuration for 7B model
Base_7B_Config = LazyDict(
dict(
defaults=[
{"override /net": "faditv2_7b"},
{"override /conditioner": "add_fps_image_size_padding_mask"},
{"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"},
"_self_",
],
model=dict(
latent_shape=[16, 16, 88, 160],
net=dict(
rope_h_extrapolation_ratio=1,
rope_w_extrapolation_ratio=1,
rope_t_extrapolation_ratio=2,
),
),
job=dict(
group="Control2World",
name="Base_7B_Config",
),
)
)
def make_ctrlnet_config_7b(
hint_key: str = "control_input_seg",
num_control_blocks: int = 3,
) -> LazyDict:
hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key])
return LazyDict(
dict(
defaults=[
"/experiment/Base_7B_Config",
{"override /hint_key": hint_key},
{"override /net_ctrl": "faditv2_7b"},
{"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"},
],
job=dict(
group="CTRL_7Bv1_lvg",
name=f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}",
project="cosmos_transfer1",
),
model=dict(
hint_mask=hint_mask,
hint_dropout_rate=0.3,
conditioner=dict(video_cond_bool=dict()),
net=L(VideoExtendGeneralDIT)(
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
),
net_ctrl=dict(
in_channels=17,
hint_channels=128,
num_blocks=28,
layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)],
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
),
),
model_obj=L(VideoDiffusionModelWithCtrl)(),
)
)
def make_ctrlnet_config_7b_t2v(
hint_key: str = "control_input_seg",
num_control_blocks: int = 3,
) -> LazyDict:
hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key])
return LazyDict(
dict(
defaults=[
"/experiment/Base_7B_Config",
{"override /hint_key": hint_key},
{"override /net_ctrl": "faditv2_7b"},
{"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"},
],
job=dict(
group="CTRL_7Bv1_t2v",
name=f"CTRL_7Bv1pt3_t2v_121frames_{hint_key}_block{num_control_blocks}",
project="cosmos_ctrlnet1",
),
model=dict(
base_load_from=dict(
load_path=f"checkpoints/{BASE_7B_CHECKPOINT_AV_SAMPLE_PATH}",
),
hint_mask=hint_mask,
hint_dropout_rate=0.3,
net=dict(
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
),
net_ctrl=dict(
in_channels=16,
hint_channels=16,
num_blocks=28,
layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)],
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
),
),
model_obj=L(VideoDiffusionT2VModelWithCtrl)(),
)
)
def make_ctrlnet_config_7b_mv(
hint_key: str = "control_input_seg",
num_control_blocks: int = 3,
t2w: bool = True,
) -> LazyDict:
hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key])
return LazyDict(
dict(
defaults=[
"/experiment/Base_7B_Config",
{"override /hint_key": hint_key},
{"override /net_ctrl": "faditv2_7b_mv"},
{"override /conditioner": "view_cond_ctrlnet_add_fps_image_size_padding_mask"},
],
job=dict(
group="CTRL_7Bv1_mv",
name=f"CTRL_7Bv1pt3_sv2mv_{'t2w' if t2w else 'v2w'}_57frames_{hint_key}_block{num_control_blocks}",
project="cosmos_ctrlnet1",
),
model=dict(
n_views=6,
base_load_from=dict(
load_path=f"checkpoints/{BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}"
if t2w
else f"checkpoints/{BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}",
),
hint_mask=hint_mask,
hint_dropout_rate=0.3,
conditioner=dict(
video_cond_bool=dict(
condition_location="first_cam" if t2w else "first_cam_and_first_n",
)
),
net=L(MultiViewVideoExtendGeneralDIT)(
n_views=6,
n_views_emb=7,
camera_condition_dim=6,
add_repeat_frame_embedding=True,
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
),
net_ctrl=dict(
in_channels=16,
hint_channels=16,
num_blocks=28,
n_views=6,
n_views_emb=7,
camera_condition_dim=6,
add_repeat_frame_embedding=True,
is_extend_model=True,
layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)],
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
),
tokenizer=dict(
video_vae=dict(
pixel_chunk_duration=57,
)
),
),
model_obj=L(MultiVideoDiffusionModelWithCtrl)(),
)
)
def make_ctrlnet_config_7b_mv_waymo(
hint_key: str = "control_input_seg",
num_control_blocks: int = 3,
t2w: bool = True,
) -> LazyDict:
hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key])
return LazyDict(
dict(
defaults=[
"/experiment/Base_7B_Config",
{"override /hint_key": hint_key},
{"override /net_ctrl": "faditv2_7b_mv"},
{"override /conditioner": "view_cond_ctrlnet_add_fps_image_size_padding_mask"},
],
job=dict(
group="CTRL_7Bv1_mv",
name=f"CTRL_7Bv1pt3_sv2mv_{'t2w' if t2w else 'v2w'}_57frames_{hint_key}_waymo_block{num_control_blocks}",
project="cosmos_ctrlnet1",
),
model=dict(
n_views=5,
base_load_from=dict(
load_path=f"checkpoints/{BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}"
if t2w
else f"checkpoints/{BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}",
),
hint_mask=hint_mask,
hint_dropout_rate=0.15,
conditioner=dict(
video_cond_bool=dict(
condition_location="first_cam" if t2w else "first_cam_and_first_n",
cfg_unconditional_type="zero_condition_region_condition_mask",
apply_corruption_to_condition_region="noise_with_sigma",
condition_on_augment_sigma=False,
dropout_rate=0.0,
first_random_n_num_condition_t_max=0 if t2w else 2,
normalize_condition_latent=False,
augment_sigma_sample_p_mean=-3.0,
augment_sigma_sample_p_std=2.0,
augment_sigma_sample_multiplier=1.0,
)
),
net=L(MultiViewVideoExtendGeneralDIT)(
n_views=5,
n_views_emb=7,
camera_condition_dim=6,
add_repeat_frame_embedding=True,
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
num_blocks=28,
),
adjust_video_noise=True,
net_ctrl=dict(
in_channels=16,
hint_channels=16,
num_blocks=28,
n_views=5,
n_views_emb=7,
camera_condition_dim=6,
add_repeat_frame_embedding=True,
is_extend_model=True,
layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)],
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
),
tokenizer=dict(
video_vae=dict(
pixel_chunk_duration=57,
)
),
),
model_obj=L(MultiVideoDiffusionModelWithCtrl)(),
)
)
# Register base configs
cs.store(group="experiment", package="_global_", name=Base_7B_Config["job"]["name"], node=Base_7B_Config)
# Register all control configurations
num_control_blocks = 3
for key in CTRL_HINT_KEYS_COMB.keys():
# Register 7B configurations
config_7b = make_ctrlnet_config_7b(hint_key=key, num_control_blocks=num_control_blocks)
cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b)
# Register t2v based control net
num_control_blocks = 3
for key in ["control_input_hdmap", "control_input_lidar"]:
# Register 7B configurations
config_7b = make_ctrlnet_config_7b_t2v(hint_key=key, num_control_blocks=num_control_blocks)
cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b)
num_control_blocks = 3
for key in ["control_input_hdmap", "control_input_lidar"]:
for t2w in [True, False]:
# Register 7B sv2mv configurations
config_7b = make_ctrlnet_config_7b_mv(hint_key=key, num_control_blocks=num_control_blocks, t2w=t2w)
cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b)
# Register waymo example
num_control_blocks = 3
for key in ["control_input_hdmap", "control_input_lidar"]:
for t2w in [True, False]:
# Register 7B sv2mv configurations
config_7b = make_ctrlnet_config_7b_mv_waymo(hint_key=key, num_control_blocks=num_control_blocks, t2w=t2w)
cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b)
def make_ctrlnet_config_7b_distilled(
hint_key: str = "control_input_edge",
num_control_blocks: int = 3,
) -> LazyDict:
hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key])
return LazyDict(
dict(
defaults=[
"/experiment/Base_7B_Config",
{"override /hint_key": hint_key},
{"override /net": "faditv2_7b"},
{"override /net_ctrl": "faditv2_7b"},
{"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"},
"_self_",
],
job=dict(
group="DISTILL_CTRL_7Bv1",
name=f"CTRL_7Bv1pt3_lvg_fsdp_distilled_121frames_{hint_key}_block{num_control_blocks}",
project="cosmos_nano_v1",
),
model=dict(
base_load_from=dict(
load_path=f"checkpoints/{EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH}",
),
hint_mask=hint_mask,
hint_dropout_rate=0.0,
conditioner=dict(
video_cond_bool=dict(
condition_location="first_random_n",
cfg_unconditional_type="zero_condition_region_condition_mask",
apply_corruption_to_condition_region="noise_with_sigma_fixed",
condition_on_augment_sigma=False,
dropout_rate=0.0,
first_random_n_num_condition_t_max=2,
first_random_n_num_condition_t_min=0,
normalize_condition_latent=False,
augment_sigma_sample_p_mean=-3.0,
augment_sigma_sample_p_std=2.0,
augment_sigma_sample_multiplier=1.0,
)
),
net=L(VideoExtendGeneralDIT)(
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
rope_t_extrapolation_ratio=2,
),
net_ctrl=dict(
in_channels=17,
hint_channels=128,
num_blocks=28,
layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)],
num_control_blocks=num_control_blocks,
dropout_ctrl_branch=0,
extra_per_block_abs_pos_emb=True,
pos_emb_learnable=True,
extra_per_block_abs_pos_emb_type="learnable",
),
),
model_obj=L(VideoDistillModelWithCtrl)(),
)
)
# Register the specific distilled configuration
distilled_config = make_ctrlnet_config_7b_distilled(hint_key="control_input_edge", num_control_blocks=3)
cs.store(
group="experiment",
package="_global_",
name="dev_v2w_ctrl_7bv1pt3_VisControlCanny_video_only_dmd2_fsdp",
node=distilled_config,
)