|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Registry for training experiments, callbacks and data. |
|
""" |
|
|
|
import copy |
|
|
|
from hydra.core.config_store import ConfigStore |
|
|
|
import cosmos_transfer1.diffusion.config.training.registry as base_training_registry |
|
from cosmos_transfer1.diffusion.config.base.data import register_data_ctrlnet |
|
from cosmos_transfer1.diffusion.config.registry import register_conditioner |
|
from cosmos_transfer1.diffusion.config.training.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 |
|
from cosmos_transfer1.diffusion.config.transfer.conditioner import ( |
|
CTRL_HINT_KEYS, |
|
BaseVideoConditionerWithCtrlConfig, |
|
VideoConditionerFpsSizePaddingWithCtrlConfig, |
|
ViewConditionedVideoConditionerFpsSizePaddingWithCtrlConfig, |
|
) |
|
from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT |
|
from cosmos_transfer1.diffusion.training.networks.general_dit_ctrl_enc import GeneralDITEncoder |
|
from cosmos_transfer1.diffusion.training.networks.general_dit_ctrl_enc_multicamera import GeneralDITMulticamEncoder |
|
|
|
|
|
from cosmos_transfer1.utils.lazy_config import LazyCall as L |
|
from cosmos_transfer1.utils.lazy_config import LazyDict |
|
|
|
# Base GeneralDIT training network config. It is stored under net/faditv2_7b in
# register_net_train and deep-copied as the template for the ControlNet encoder
# configs, which reuse every hyperparameter here.
FADITV2ConfigTrain: LazyDict = L(GeneralDIT)(
    max_img_h=240,
    max_img_w=240,
    max_frames=128,
    in_channels=16,
    out_channels=16,
    patch_spatial=2,
    patch_temporal=1,
    model_channels=4096,
    # NOTE(review): presumably full-attention / cross-attention / MLP sub-blocks — confirm in GeneralDIT.
    block_config="FA-CA-MLP",
    # Read back right after this literal to size the encoder `layer_mask` lists.
    num_blocks=28,
    num_heads=32,
    concat_padding_mask=True,
    pos_emb_cls="rope3d",
    pos_emb_learnable=False,
    pos_emb_interpolation="crop",
    block_x_format="THWBD",
    additional_timestamp_channels=None,
    # NOTE(review): the "affline" spelling presumably mirrors GeneralDIT's keyword
    # name; do not rename it here without changing GeneralDIT as well.
    affline_emb_norm=True,
    use_adaln_lora=True,
    adaln_lora_dim=256,
    legacy_patch_emb=False,
)
|
|
|
# Number of DiT blocks in the base config. It is computed once and kept as a
# module-level name in case other modules import it from here.
num_blocks = FADITV2ConfigTrain["num_blocks"]


def _make_ctrl_encoder_config(target) -> LazyDict:
    """Build a ControlNet-encoder config from the base DiT training config.

    The base config is deep-copied, so the returned config shares no mutable
    state with ``FADITV2ConfigTrain`` or with other configs built here.

    Args:
        target: Network class to instantiate instead of ``GeneralDIT``.

    Returns:
        A new ``LazyDict`` whose ``_target_`` is ``target``. Its ``layer_mask``
        is a list of bools of length ``num_blocks`` that is ``True`` exactly for
        block indices strictly greater than ``num_blocks // 2``.
    """
    cfg = copy.deepcopy(FADITV2ConfigTrain)
    cfg["_target_"] = target
    n_blocks = cfg["num_blocks"]
    # NOTE(review): the strict ">" leaves the first n_blocks // 2 + 1 entries
    # False (15 of 28 for the base config). Confirm this split is intended
    # rather than ">=".
    cfg["layer_mask"] = [i > n_blocks // 2 for i in range(n_blocks)]
    return cfg


# Single-view and multi-camera ControlNet encoders. Both are derived from the
# same base config and differ only in the instantiated class.
FADITV2EncoderConfigTrain = _make_ctrl_encoder_config(GeneralDITEncoder)
FADITV2MultiCamEncoderConfig = _make_ctrl_encoder_config(GeneralDITMulticamEncoder)
|
|
|
|
|
def register_net_train(cs):
    """Register the base DiT network config and its ControlNet encoder variants.

    Args:
        cs: Hydra ``ConfigStore`` to store the entries in.
    """
    # Each entry is (group, package, name, node), stored in this order.
    net_entries = (
        ("net", "model.net", "faditv2_7b", FADITV2ConfigTrain),
        ("net_ctrl", "model.net_ctrl", "faditv2_7b", FADITV2EncoderConfigTrain),
        ("net_ctrl", "model.net_ctrl", "faditv2_sv2mv", FADITV2MultiCamEncoderConfig),
    )
    for group, package, entry_name, node in net_entries:
        cs.store(group=group, package=package, name=entry_name, node=node)
|
|
|
|
|
def register_conditioner_ctrlnet(cs):
    """Register the ControlNet video-conditioner configs in the ``conditioner`` group.

    Args:
        cs: Hydra ``ConfigStore`` to store the entries in.
    """
    # Dict insertion order fixes the order in which entries are stored.
    conditioner_nodes = {
        "ctrlnet": BaseVideoConditionerWithCtrlConfig,
        "ctrlnet_add_fps_image_size_padding_mask": VideoConditionerFpsSizePaddingWithCtrlConfig,
        "view_cond_ctrlnet_add_fps_image_size_padding_mask": (
            ViewConditionedVideoConditionerFpsSizePaddingWithCtrlConfig
        ),
    }
    for entry_name, node in conditioner_nodes.items():
        cs.store(
            group="conditioner",
            package="model.conditioner",
            name=entry_name,
            node=node,
        )
|
|
|
|
|
def register_tokenizer(cs):
    """Register the comp8x8x8 diffusion tokenizer config.

    The tokenizer is built with ``resolution="720"`` and ``chunk_duration=121``.

    Args:
        cs: Hydra ``ConfigStore`` to store the entry in.
    """
    tokenizer_node = get_cosmos_diffusion_tokenizer_comp8x8x8(resolution="720", chunk_duration=121)
    cs.store(
        group="tokenizer",
        package="model.tokenizer",
        name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624",
        node=tokenizer_node,
    )
|
|
|
|
|
def _register_hint_keys(cs):
    """Store a plain entry and a ``_grayscale`` entry for every key in ``CTRL_HINT_KEYS``.

    Each entry goes into group ``hint_key`` with package ``model``, and its node
    sets ``hint_key.hint_key`` and ``hint_key.grayscale``.

    Args:
        cs: Hydra ``ConfigStore`` to store the entries in.
    """
    for key in CTRL_HINT_KEYS:
        # The plain (grayscale=False) entry is stored first, then the grayscale one.
        for grayscale in (False, True):
            entry_name = f"{key}_grayscale" if grayscale else key
            cs.store(
                group="hint_key",
                package="model",
                name=entry_name,
                node=dict(hint_key=dict(hint_key=key, grayscale=grayscale)),
            )


def register_configs():
    """Populate the global Hydra ``ConfigStore`` with all ControlNet training configs.

    This registers networks, conditioners, the tokenizer, the base training
    registry, ControlNet data configs and the per-hint-key model overrides.
    """
    config_store = ConfigStore.instance()

    # NOTE(review): the call order is kept exactly as written. If two calls
    # store the same group/name, the later one presumably overwrites the
    # earlier one.
    register_net_train(config_store)
    register_conditioner(config_store)
    register_conditioner_ctrlnet(config_store)
    register_tokenizer(config_store)

    # The base registry takes no ConfigStore argument; presumably it fetches
    # the singleton itself.
    base_training_registry.register_configs()

    register_data_ctrlnet(config_store)

    _register_hint_keys(config_store)
|
|