diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d4c176a2cb38584cbcc5ac6c2cfe90ec85c4a396 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo_1.gif filter=lfs diff=lfs merge=lfs -text +assets/demo_2.gif filter=lfs diff=lfs merge=lfs -text +assets/demo_3.gif filter=lfs diff=lfs merge=lfs -text +assets/demo_dynamic.gif filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000000.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000001.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000002.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000003.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000004.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000005.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000006.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000007.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000008.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000009.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000010.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000011.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000012.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000013.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000014.png filter=lfs diff=lfs merge=lfs -text +assets/diffusion/000015.png filter=lfs diff=lfs merge=lfs -text +cosmos_predict1/tokenizer/test_data/image.png filter=lfs diff=lfs merge=lfs -text +cosmos_predict1/tokenizer/test_data/video.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/assets/demo_1.gif b/assets/demo_1.gif new file mode 100644 index 0000000000000000000000000000000000000000..9a54e100aefeb6ed547109385025c22d0322ad29 --- /dev/null +++ b/assets/demo_1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6162366c56277d084b05a37c617e2994ba75285d421e203556dcff08128b32b +size 14678966 diff --git a/assets/demo_2.gif b/assets/demo_2.gif new file mode 100644 index 0000000000000000000000000000000000000000..54715e1e65fb6428bded8fe88526f33f22608a62 --- /dev/null +++ b/assets/demo_2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e765e71d3016c6e314b6403f82313a1df42f68f6fb0f9416f197d82e0710f27e +size 10573280 diff --git a/assets/demo_3.gif b/assets/demo_3.gif new file mode 100644 index 0000000000000000000000000000000000000000..36cd315a9e756bdd237b6924ff7e0e671bf3d406 --- /dev/null +++ b/assets/demo_3.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c4cf4a4bf62daf03b25ac66c2c3693adbf7cd459e55d3481a65a9ff4a9d09d9 +size 35276047 diff --git a/assets/demo_dynamic.gif b/assets/demo_dynamic.gif new file mode 100644 index 0000000000000000000000000000000000000000..f96dde75172a618a3c2b2aacd4a276e43b1f4185 --- /dev/null +++ b/assets/demo_dynamic.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:174faba45ae701eaa432dd14de1297c0479b6c0b832adbc211cbb529fbec6c61 +size 24517788 diff --git a/assets/diffusion/000000.png b/assets/diffusion/000000.png new file mode 100644 index 0000000000000000000000000000000000000000..7d531d6587b9cb68cb9a77d5be1ad709027c025b --- /dev/null +++ b/assets/diffusion/000000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:b7e6eab7548c2ede900f8b504a5cef981e0cd0ec38af90dbea3f0db860e002c3 +size 1326071 diff --git a/assets/diffusion/000001.png b/assets/diffusion/000001.png new file mode 100644 index 0000000000000000000000000000000000000000..d754ec6803d60186de066118da8406dad11af7ef --- /dev/null +++ b/assets/diffusion/000001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abe310078829c9e1375ac30c7c270c84c8f68a09f3857bd35c7a5754f3326151 +size 1131209 diff --git a/assets/diffusion/000002.png b/assets/diffusion/000002.png new file mode 100644 index 0000000000000000000000000000000000000000..1f3f5f0279e10e718a3478d795db735cbece9d5f --- /dev/null +++ b/assets/diffusion/000002.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ad89b53e9fafed0d8eefd1cfc7cc4889c5d2f510ed32d5247c5adab4cb0c622 +size 789185 diff --git a/assets/diffusion/000003.png b/assets/diffusion/000003.png new file mode 100644 index 0000000000000000000000000000000000000000..e2999ff690a749007b70b5e5a25ee3a21c04ff35 --- /dev/null +++ b/assets/diffusion/000003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22f39915f1b277e70683befbc18ac5859c65c3d389e4dbb5127a539a411fec54 +size 1105958 diff --git a/assets/diffusion/000004.png b/assets/diffusion/000004.png new file mode 100644 index 0000000000000000000000000000000000000000..20f4fb80c925e51e3c31a597107b3636ea9851c6 --- /dev/null +++ b/assets/diffusion/000004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2f957208849c0f86b89545734bb7b243868b574554cb6aeed248b04e7234ad4 +size 1262412 diff --git a/assets/diffusion/000005.png b/assets/diffusion/000005.png new file mode 100644 index 0000000000000000000000000000000000000000..0aa49c43c45cdb634da5d424fb8c882be31cb354 --- /dev/null +++ b/assets/diffusion/000005.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:267f6ae47d0e2aebda89fac5416bc0915855043131d0d8d8a4fc9506cabd4681 +size 1364198 diff --git a/assets/diffusion/000006.png b/assets/diffusion/000006.png new file mode 100644 index 0000000000000000000000000000000000000000..668af465ce603e33788278b460dfda72ed308b1b --- /dev/null +++ b/assets/diffusion/000006.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b6fd098366bcd54bd21a5707ae6d9f78d74c2eefcfbb6919569c0d1741d837f +size 1207409 diff --git a/assets/diffusion/000007.png b/assets/diffusion/000007.png new file mode 100644 index 0000000000000000000000000000000000000000..ac9a6a0a297bcebceeea924b5db0255d167ef141 --- /dev/null +++ b/assets/diffusion/000007.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:334733b7428f9521e625a8b310770fbba3e4616ccbe0af625d07e2b065e6e9ad +size 1150728 diff --git a/assets/diffusion/000008.png b/assets/diffusion/000008.png new file mode 100644 index 0000000000000000000000000000000000000000..677f6afcb6963858a98ebb2070e220bb19ad41af --- /dev/null +++ b/assets/diffusion/000008.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7eae1abb3343c1e11f4e42172eba85eeed0fb2a5f7701a42e5003cf84f1696cd +size 1684291 diff --git a/assets/diffusion/000009.png b/assets/diffusion/000009.png new file mode 100644 index 0000000000000000000000000000000000000000..e19b55a92abc9c737fb23a6b493f5a22cfd38e0a --- /dev/null +++ b/assets/diffusion/000009.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a5c5711d41f56bb307ef6020d0dffec9ce2297bda9ef9ae465237d8347adb34 +size 603167 diff --git a/assets/diffusion/000010.png 
b/assets/diffusion/000010.png new file mode 100644 index 0000000000000000000000000000000000000000..341aad51799c0111c65c58b4bb0e07209e0be04a --- /dev/null +++ b/assets/diffusion/000010.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4d32f1d1c6d427e421d6f4478d4c2c697cb0406a18ecc3b8ebeeb2a0cbba7f5 +size 1184019 diff --git a/assets/diffusion/000011.png b/assets/diffusion/000011.png new file mode 100644 index 0000000000000000000000000000000000000000..72d11ac239d063aa53298ec1040fa2f27c7735a7 --- /dev/null +++ b/assets/diffusion/000011.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e352d7435d3b313fcc47efd9bd0dc6e0dd5d5e8af8c50e965c57987bee1c94ec +size 944420 diff --git a/assets/diffusion/000012.png b/assets/diffusion/000012.png new file mode 100644 index 0000000000000000000000000000000000000000..c685fc6bfe8c6730b007ddb762bffd3c51962a70 --- /dev/null +++ b/assets/diffusion/000012.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b672d43521890b2852976a0c12828ad16b9288277efff6c41189dc0c04c9c6e1 +size 1098037 diff --git a/assets/diffusion/000013.png b/assets/diffusion/000013.png new file mode 100644 index 0000000000000000000000000000000000000000..6fd722831d73a54f25f9dd20014e91f72aee68d6 --- /dev/null +++ b/assets/diffusion/000013.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eab3a655213eede094889bab94313e1cef142b811429bee9e0f3420c2b013105 +size 1243979 diff --git a/assets/diffusion/000014.png b/assets/diffusion/000014.png new file mode 100644 index 0000000000000000000000000000000000000000..432386657cc4c969a1fc052ce7c1e3d2109beee8 --- /dev/null +++ b/assets/diffusion/000014.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb014db53082677aca35a3fc27daa1f306452c5cb7130a4ed6468cae144a0b63 +size 1351667 diff --git a/assets/diffusion/000015.png b/assets/diffusion/000015.png new file mode 100644 index 0000000000000000000000000000000000000000..2c76996a58c78c95bec945bb0f0c11777bad0989 --- /dev/null +++ b/assets/diffusion/000015.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6ac0d4e7eb6d4dbc3ae997fafc28721b716db092aaa52ede11e4d87b3e9b20d +size 1494431 diff --git a/checkpoints/.DS_Store b/checkpoints/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..280fb22bf245a0e54bad8bd970c7d3ba5f4e35d4 Binary files /dev/null and b/checkpoints/.DS_Store differ diff --git a/checkpoints/README.md b/checkpoints/README.md new file mode 100644 index 0000000000000000000000000000000000000000..726899abdbae8de94885b0c5bc111291fb8dce7a --- /dev/null +++ b/checkpoints/README.md @@ -0,0 +1,4 @@ + +### Checkpoint directory + +Model checkpoints will be downloaded to this directory. diff --git a/cosmos-predict1.yaml b/cosmos-predict1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0722589d77b183cdd7b865227b2e0cb934e27088 --- /dev/null +++ b/cosmos-predict1.yaml @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# conda env create --file cosmos-predict1.yaml +name: cosmos-predict1 +channels: + - conda-forge +dependencies: + - python=3.10 + - pip=25.0 + - cmake + - ninja + - gcc=12.4.0 + - gxx=12.4.0 + - cuda=12.4 + - cuda-nvcc=12.4 + - cuda-toolkit=12.4 diff --git a/cosmos_predict1/.DS_Store b/cosmos_predict1/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..1a2fd51f19bd3da74036942587155d0663379b8d Binary files /dev/null and b/cosmos_predict1/.DS_Store differ diff --git a/cosmos_predict1/__init__.py b/cosmos_predict1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dac9a4d7496eb38831f1f3c820a90d50e25e2a7e --- /dev/null +++ b/cosmos_predict1/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/__init__.py b/cosmos_predict1/autoregressive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py b/cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py new file mode 100644 index 0000000000000000000000000000000000000000..f0df8f71dfd5d79142685210de71fd2c45e87f5c --- /dev/null +++ b/cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py @@ -0,0 +1,352 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import math +import os +from typing import Optional + +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as torchvision_F +import wandb +from einops import rearrange +from megatron.core import parallel_state +from torch.distributed import get_process_group_ranks + +from cosmos_predict1.autoregressive.utils.parallel import ( + broadcast_data_batch_in_tp_cp_group, + gather_batch_from_cp_ranks, + get_batch_on_this_cp_rank, +) +from cosmos_predict1.callbacks.every_n import EveryN +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +def resize_image(image: torch.Tensor, resize_factor=0.5) -> torch.Tensor: + _, _, h, w = image.shape + new_h, new_w = int(resize_factor * h), int(resize_factor * w) + return torchvision_F.resize(image, (new_h, new_w)) + + +class VideoSamplingTeacherForcing(EveryN): + def __init__( + self, + every_n: int, + step_size: int = 1, + video_latent_shape: list = [6, 24, 40], + num_frames_to_display: int = 4, + save_folder: Optional[str] = None, + num_file_to_log: int = 8, + ): + r""" + This callback enables us to perform teacher forcing inference on the training data. + By teacher forcing, we mean providing ground truth video tokens as inputs, and simply asking the model + to predict the next tokens. The predicted next tokens are then visualized. This does not perform + autoregressive sampling. + We also upload the downsampled video frames to wandb. Downsampling is needed for wandb to work fast. + + Args: + every_n (int): Call this callback every_n steps + step_size (int): Number of steps taken for gradient accumulation. 
Global iteration number is + iteration // self.step_size + video_latent_shape (list): Shape of the video latent + num_frames_to_display (int): Number of frames to subsample for displaying in wandb + save_folder (str): Name of the local folder to save the video + num_file_to_log (int): Number of files to upload to wandb + """ + super().__init__(every_n, step_size) + self.save_folder = save_folder if save_folder else self.__class__.__name__ + self.video_latent_shape = video_latent_shape + self.num_frames_to_display = num_frames_to_display + self.num_file_to_log = num_file_to_log + self.rank = distributed.get_rank() + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + config_job = self.config.job + self.local_dir = f"{config_job.path_local}/{self.save_folder}" + if self.rank == 0: + os.makedirs(self.local_dir, exist_ok=True) + log.info(f"Video Teacher-Forcing Callback: local_dir: {self.local_dir}") + + @torch.inference_mode() + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + # Tokenize the data + + broadcast_data_batch_in_tp_cp_group(data_batch) + + input_vid = data_batch[model.tokenizer.tokenizer_config.video_tokenizer.data_key] + + dataset_name = data_batch.get("dataset_name", None) + if dataset_name is not None and dataset_name.startswith("image"): + # we disable the callback if the input video is an image batch + log.info(f"dataset_name is {dataset_name}, skip this callback") + return + + # get the caption + captions = data_batch.get("caption", None) + + # get the context embedding and mask + context = data_batch.get("context", None) + context_mask = data_batch.get("context_mask", None) + if context is not None: + context = misc.to(context, "cuda").detach().clone() + if context_mask is not None: + context_mask = misc.to(context_mask, "cuda").detach().clone() + # get the action + action = data_batch.get("action", None) + if action is not None: + action = misc.to(action, "cuda").detach().clone() + + # Input tokens + tokens, _ = model.tokenizer.tokenize(data_batch) + tokens = misc.to(tokens, "cuda").detach().clone() + skip_save_file = False + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + if self.rank != min(get_process_group_ranks(cp_group)): + skip_save_file = True + tokens = get_batch_on_this_cp_rank(tokens) + if parallel_state.get_tensor_model_parallel_world_size() > 1: + # Turn on TP + tp_group = parallel_state.get_tensor_model_parallel_group() + if self.rank != min(get_process_group_ranks(tp_group)): + skip_save_file = True + tokens_encoded_in_train = output_batch["encode_tokens"].detach() + percent_token_diff = (tokens != tokens_encoded_in_train).float().mean() + percent_token_diff = distributed.dist_reduce_tensor(percent_token_diff) + + input_tokens = tokens + + num_tokens_to_generate = np.prod(self.video_latent_shape) + + # Do a forward pass + logits = model.model.forward( + tokens, + input_pos=None, + context=context, + context_mask=context_mask, + action=action, + ) + if parallel_state.get_context_parallel_world_size() > 1: + logits = gather_batch_from_cp_ranks(logits) + input_tokens = gather_batch_from_cp_ranks(input_tokens) + + # Start position for video tokens in the vocabulary + video_token_start = self.config.model.tokenizer_config.video_tokenizer.tokenizer_offset + video_vocab_size = 
self.config.model.tokenizer_config.video_tokenizer.vocab_size + + # Clipping logits only to video tokens. We remove the text vocab predictions. + # This will ensure that the video tokens only correspond to the video part of the vocabulary. + logits = logits[:, :, video_token_start : video_token_start + video_vocab_size] + + # Sample with argmax token. This should be good for teacher forcing experiment. + logits = logits.contiguous() + generations = torch.argmax(logits, dim=-1) + + # For each video in the batch, subsample frames for display + batch_size = input_tokens.shape[0] + out_frames = [] + out_videos_gen = [] + out_videos_rec = [] + out_videos_gt = [] + # log the accuracy of teacher-forcing + acc = [] + loss_list = [] + + for sample_num in range(batch_size): + # Subsample the generations to the video part. + # This corresponds to the part from begin of video to end of video. + bov_token = model.tokenizer.video_special_tokens["<|begin_of_video|>"] + bov_index = input_tokens[sample_num] == bov_token + use_special_token = sum(bov_index) != 0 + if use_special_token: + bov_index = bov_index.nonzero().item() + # generations: real_token1 real_token2, ... real_token7680; total 7680 + # gen_video_tokens: real_token1 real_token2, ..., real_token7680; total 7680 + # for vis: real_token1 real_token2, ..., real_token7680; total 7680 + # for accuracy: real_token1 real_token2, ..., real_token7680; total 7680 + gen_video_tokens = generations[sample_num][bov_index : bov_index + num_tokens_to_generate] + gen_video_tokens_vis = gen_video_tokens + gen_video_tokens_acc = gen_video_tokens + logits_loss = logits[sample_num][bov_index : bov_index + num_tokens_to_generate] + else: + # generations: real_token1 real_token2, ... real_token7680 + # gen_video_tokens: real_token2 real_token3, ..., real_token7680; total 7679 + # We need different tokens for vis and accuracy compute + # for acc: real_token2 real_token3, ..., real_token7680; total 7679 + # for vis: pad_token (real_token2, ..., real_token7680); total 1 + 7679 + gen_video_tokens = generations[sample_num][ + : num_tokens_to_generate - 1 + ] # remove the last token since there is no gt + # Since the first token is not predicted, we need to add the gt first token to make sure the shape is correct + gen_video_tokens_vis = torch.cat([input_tokens[sample_num][0:1], gen_video_tokens]) + gen_video_tokens_acc = gen_video_tokens + logits_loss = logits[sample_num][: num_tokens_to_generate - 1] + + # Rearrange the video to a spatial tensor + gen_video_tokens_vis_BTHW = rearrange( + gen_video_tokens_vis.unsqueeze(0), + "B (T H W) -> B T H W", + T=self.video_latent_shape[0], + H=self.video_latent_shape[1], + W=self.video_latent_shape[2], + ) + + # for real videos, we need to skip the bov and eov tokens for decoding + if use_special_token: + # input_tokens: real_token1 real_token2 ... ... + # real_video_tokens: real_token1 real_token2 ... real_token7680; total 7680 + # for vis: real_token1 real_token2 ... real_token7680; total 7680 + # for accuracy: real_token1 real_token2 ... real_token7680; total 7680; we include real_token1 since the output prediction also includes it, see gen_video_tokens_acc above + real_video_tokens = ( + input_tokens[sample_num][bov_index + 1 : bov_index + num_tokens_to_generate + 1] - video_token_start + ) + real_video_tokens_vis = real_video_tokens + real_video_tokens_acc = real_video_tokens + else: + # input_tokens: real_token1 real_token2,... real_token7680; total 7680 + # real_video_tokens: real_token1 real_token2,... 
real_token7680; total 7680 + # for acc: gt start from real_token2, real_token3; total 7679, remove the first token since it is not predicted + # for vis: gt start from real_token1, real_token2; total 7680 + real_video_tokens = input_tokens[sample_num][:num_tokens_to_generate] - video_token_start + real_video_tokens_vis = real_video_tokens + real_video_tokens_acc = real_video_tokens[1:].flatten() + + real_video_tokens_vis_BTHW = rearrange( + real_video_tokens_vis.unsqueeze(0), + "B (T H W) -> B T H W", + T=self.video_latent_shape[0], + H=self.video_latent_shape[1], + W=self.video_latent_shape[2], + ) + # Calculate accuracy + correct_predictions = (gen_video_tokens_acc == real_video_tokens_acc).float() + labels = real_video_tokens_acc.clone() + + if model.config.ignore_first_num_tokens > 0: + labels[: model.config.ignore_first_num_tokens] = model.tokenizer.ignore_index + select_index = labels != model.tokenizer.ignore_index + correct_predictions = correct_predictions[select_index] + + loss = torch.nn.functional.cross_entropy( + logits_loss, labels, ignore_index=model.tokenizer.ignore_index, reduction="none" + ) + acc.append(correct_predictions.mean() * 100.0) + loss_list.append(loss.mean()) + + # Decode the predicted latents + if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0: + vid_decoded = model.tokenizer.video_tokenizer.decode(gen_video_tokens_vis_BTHW.cuda()) + else: + vid_decoded = model.tokenizer.video_tokenizer.decode_with_overlap( + gen_video_tokens_vis_BTHW.cuda(), + temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap, + ) + # normalize decoded images from [-1, 1] to [0, 1], and clip value + vid_decoded = (vid_decoded * 0.5 + 0.5).clamp_(0, 1) + vid_decoded = vid_decoded[0] + + # Decode the GT latents + if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0: + vid_rec = model.tokenizer.video_tokenizer.decode(real_video_tokens_vis_BTHW.cuda()) + else: + vid_rec = model.tokenizer.video_tokenizer.decode_with_overlap( + real_video_tokens_vis_BTHW.cuda(), + temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap, + ) + # normalize decoded image from [-1, 1] to [0, 1], and clip value + vid_rec = (vid_rec * 0.5 + 0.5).clamp_(0, 1) + vid_rec = vid_rec[0] + + vid_input = input_vid[sample_num] # [-1, 1], input_vid shape: [B, C, L, H, W] + vid_input = (vid_input * 0.5 + 0.5).clamp_(0, 1).cuda() # Convert to [0, 1], [C, L, H, W] + + # Subsample real and generated video frames + input_video_frames = vid_input.transpose(0, 1) # [L, C, H, W] + rec_video_frames = vid_rec.transpose(0, 1) + gen_video_frames = vid_decoded.transpose(0, 1) + out_videos_gen.append(gen_video_frames) + out_videos_rec.append(rec_video_frames) + out_videos_gt.append(input_video_frames) + + stride = math.ceil(rec_video_frames.shape[0] / self.num_frames_to_display) + + input_video_frames_subsampled = resize_image(input_video_frames[0::stride], resize_factor=0.5) + input_video_frames_subsampled = torchvision.utils.make_grid( + input_video_frames_subsampled, nrow=input_video_frames_subsampled.shape[0] + ) + + gt_video_frames_subsampled = resize_image(rec_video_frames[0::stride], resize_factor=0.5) + gt_video_frames_subsampled = torchvision.utils.make_grid( + gt_video_frames_subsampled, nrow=gt_video_frames_subsampled.shape[0] + ) + gen_video_frames_subsampled = resize_image(gen_video_frames[0::stride], resize_factor=0.5) + gen_video_frames_subsampled = torchvision.utils.make_grid( + gen_video_frames_subsampled, 
nrow=gen_video_frames_subsampled.shape[0] + ) + + out_frames.append(input_video_frames_subsampled) + out_frames.append(gt_video_frames_subsampled) + out_frames.append(gen_video_frames_subsampled) + + scaled_num_rank_to_log = ( + self.num_file_to_log + * parallel_state.get_context_parallel_world_size() + * parallel_state.get_tensor_model_parallel_world_size() + ) + if self.rank < scaled_num_rank_to_log and not skip_save_file: + local_path = f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_{self.rank:04d}.jpg" + out_image_grid = torchvision.utils.make_grid(out_frames, nrow=1, padding=0, normalize=False) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + torchvision.utils.save_image(out_image_grid, local_path) + + # Log to wandb + avg_acc = distributed.dist_reduce_tensor(torch.stack(acc).mean()).item() + avg_loss = distributed.dist_reduce_tensor(torch.stack(loss_list).mean()).item() + log_info = "" + if "acc" in output_batch: + log_info = f"train acc: {(output_batch['acc'].mean().item()):.6f}%" + if percent_token_diff is not None: + log_info += f"; percent_token_diff_train_val: {percent_token_diff.item() * 100:.6f}%" + log.info( + f"Eval iteration {iteration} teacher-forcing accuracy: {avg_acc:.6f}%, loss: {avg_loss:.4f}; {log_info}" + ) + if self.rank == 0 and wandb.run: + local_files = glob.glob(f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_*.jpg") + local_files = sorted(local_files)[: self.num_file_to_log] + if captions is None: + captions = ["vid_frames_teacher_forcing"] * len(local_files) + for local_path, caption in zip(local_files, captions): + wandb.log( + {"frames": [wandb.Image(local_path, caption=caption)]}, + step=iteration, + ) + + wandb.log({"eval/teacher_forcing_acc": avg_acc}, step=iteration) + wandb.log({"eval/teacher_forcing_loss": avg_loss}, step=iteration) + if percent_token_diff is not None: + wandb.log({"eval/percent_token_diff_train_val": percent_token_diff.item() * 100}, step=iteration) diff --git a/cosmos_predict1/autoregressive/configs/__init__.py b/cosmos_predict1/autoregressive/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/configs/base/__init__.py b/cosmos_predict1/autoregressive/configs/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/configs/base/callbacks.py b/cosmos_predict1/autoregressive/configs/base/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..040f326f221febac7af54f7d7a64876a9fc030ef --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/callbacks.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing +from cosmos_predict1.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.lazy_config import LazyCall as L + +BASIC_CALLBACKS = dict( + progress_bar=L(ProgressBarCallback)(), + grad_clip=L(GradClip)(clip_norm=1.0, fsdp_enabled="${model.model_config.fsdp_enabled}", model_key="model"), +) + +VIDEO_TEACHER_FORCING_CALLBACK = dict( + vid_sampling_tf=L(VideoSamplingTeacherForcing)( + every_n=500, + video_latent_shape="${model.model_config.video_latent_shape}", + num_frames_to_display=4, + save_folder="video_sampling_teacher_forcing", + ) +) diff --git a/cosmos_predict1/autoregressive/configs/base/dataloader.py b/cosmos_predict1/autoregressive/configs/base/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..69458e450a7fb5d85ea25c2514e975d5540a108f --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/dataloader.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
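# Illustrative sketch (assumed helper, not taken from the patched files): the core of the
# teacher-forcing check performed by VideoSamplingTeacherForcing above. Logits are
# restricted to the video vocabulary, argmax-decoded, and compared against the ground-truth
# tokens. Tensor names and shapes here are assumptions for illustration.
import torch

def teacher_forcing_accuracy(
    logits: torch.Tensor,      # [B, T, full_vocab_size] from a single forward pass
    gt_tokens: torch.Tensor,   # [B, T] ground-truth tokens, offset into the video vocabulary
    video_token_start: int,
    video_vocab_size: int,
    ignore_first_num_tokens: int = 0,
) -> torch.Tensor:
    # Keep only the video slice of the vocabulary, mirroring the callback's clipping step.
    video_logits = logits[:, :, video_token_start : video_token_start + video_vocab_size]
    # Argmax sampling is sufficient for a teacher-forcing diagnostic.
    predictions = torch.argmax(video_logits, dim=-1)
    correct = (predictions == (gt_tokens - video_token_start)).float()
    # Optionally skip the first few positions, as ignore_first_num_tokens does in training.
    correct = correct[:, ignore_first_num_tokens:]
    return correct.mean() * 100.0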
+ +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig +from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyCall as L + +DATALOADER_OPTIONS = {} + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +def dataloader_register(key): + log.info(f"registering dataloader {key}...") + + def decorator(func): + DATALOADER_OPTIONS[key] = func + return func + + return decorator + + +@dataloader_register("tealrobot_video") +def get_tealrobot_video( + batch_size: int = 1, + dataset_dir: str = "datasets/cosmos_nemo_assets/videos/", + sequence_interval: int = 1, + num_frames: int = 33, + video_size: list[int, int] = [640, 848], + start_frame_interval: int = 1, +): + dataset = L(VideoDataset)( + config=VideoDatasetConfig( + dataset_dir=dataset_dir, + sequence_interval=sequence_interval, + num_frames=num_frames, + video_size=video_size, + start_frame_interval=start_frame_interval, + ) + ) + return L(DataLoader)( + dataset=dataset, + sampler=L(get_sampler)(dataset=dataset), + batch_size=batch_size, + drop_last=True, + pin_memory=True, + num_workers=8, + ) diff --git a/cosmos_predict1/autoregressive/configs/base/dataset.py b/cosmos_predict1/autoregressive/configs/base/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8e24fa535a0abc0b7fde86ad302a960bc6bf28 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/dataset.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
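# Usage sketch (assumed, not taken from the patched files) for the dataloader registry
# defined in dataloader.py above. The registered factory returns a lazy DataLoader spec;
# how that spec is later materialized depends on the project's lazy-config utilities,
# so no instantiation call is shown here.
from cosmos_predict1.autoregressive.configs.base.dataloader import DATALOADER_OPTIONS

# Look up the factory registered under "tealrobot_video" and build its lazy config.
dataloader_spec = DATALOADER_OPTIONS["tealrobot_video"](
    batch_size=1,
    dataset_dir="datasets/cosmos_nemo_assets/videos/",
    num_frames=33,
)
# dataloader_spec is a LazyCall(DataLoader) wrapping a VideoDataset and a DistributedSampler;
# note that get_sampler queries Megatron's parallel_state, so the spec should only be
# materialized after torch.distributed and the parallel groups are initialized.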
+ +"""Dataset config class.""" + +import attrs + +from cosmos_predict1.utils.config import make_freezable + + +@make_freezable +@attrs.define(slots=False) +class VideoDatasetConfig: + """ + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + start_frame_interval (int): Interval between starting frames of sequences + """ + + dataset_dir: str = "datasets/cosmos_nemo_assets/videos/" + sequence_interval: int = 1 + num_frames: int = 33 + video_size: list[int, int] = [640, 848] + start_frame_interval: int = 1 diff --git a/cosmos_predict1/autoregressive/configs/base/model.py b/cosmos_predict1/autoregressive/configs/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f72feb258ee3233eeadbd3afd218762648211ea5 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/model.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import attrs + +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.utils import config + +_ACTION_DIM = 8 +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define +class ModelConfig: + """ + A class to hold model configuration arguments. + + Args: + dim (int): The dimensionality of the input and output of each transformer block. + n_layers (int): Number of layers in the transformer. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to + `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention. + head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads. + vocab_size (int): Vocabulary size. + ffn_hidden_size (int): Hidden size for feedforward network. + norm_eps (float): Epsilon value for normalization. + rope_theta (float): Theta value for rotary positional embeddings. + apply_abs_pos_emb (bool): Whether to apply absolute position embeddings. + max_batch_size (int): Maximum batch size for inference. + max_seq_len (int): Maximum sequence length for input text. + fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True. + causal_mask (bool): Whether to use causal mask. Defaults to True. + norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm". + precision (str): Data type for the model. + use_qk_normalization (bool): Whether to enable QK normalization. + tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1. + ckpt_dir (str): Checkpoint directory. + ckpt_path (str): Checkpoint path. 
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension). + yarn_scale (Optional[float]): Scale factor for YaRN. + yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code) + yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code) + original_seq_len (Optional[int]): Original sequence length. + vision_encoder (Optional[str]): Vision encoder name. + mm_projector (Optional[str]): Multi-modal projector name. + vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4. + rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D". + pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2". + original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3. + insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer. + insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers. + context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim. + num_video_frames (Optional[int]): Number of video frames. + video_height (Optional[int]): Raw video pixel height dimension. + video_width (Optional[int]): Raw video pixel width dimension. + video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W). 
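# Illustrative helper (assumed, not taken from the patched files) mirroring how the
# docstring above defines video_latent_shape: frame count and resolution divided by the
# tokenizer's temporal and spatial compression factors. The factor values are assumptions;
# the actual values come from the chosen video tokenizer.
def compute_video_latent_shape(
    num_video_frames: int,
    video_height: int,
    video_width: int,
    temporal_compression: int = 8,
    spatial_compression: int = 16,
) -> list:
    return [
        num_video_frames // temporal_compression,
        video_height // spatial_compression,
        video_width // spatial_compression,
    ]

# Example: compute_video_latent_shape(48, 384, 640) -> [6, 24, 40], the same shape used
# as the video_latent_shape default in the teacher-forcing callback above.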
+ """ + + dim: int = attrs.field(default=4096) + n_layers: int = attrs.field(default=32) + n_heads: int = attrs.field(default=32) + n_kv_heads: Optional[int] = attrs.field(default=8) + head_dim: Optional[int] = attrs.field(default=None) + vocab_size: int = attrs.field(default=128256) + ffn_hidden_size: int = attrs.field(default=14336) + norm_eps: float = attrs.field(default=1e-5) + rope_theta: float = attrs.field(default=500000) + apply_abs_pos_emb: bool = attrs.field(default=False) + max_batch_size: int = attrs.field(default=1) + max_seq_len: int = attrs.field(default=8192) + fuse_qkv: bool = attrs.field(default=False) + causal_mask: bool = attrs.field(default=True) + norm_type: str = attrs.field(default="rmsnorm") + precision: str = attrs.field(default="bfloat16") + use_qk_normalization: bool = False + tokenizer: Optional[TokenizerConfig] = None + tensor_model_parallel_size: int = attrs.field(default=1) + ckpt_dir: Optional[str] = attrs.field(default=None) + ckpt_path: Optional[str] = attrs.field( + default=None + ) # If not None, load the model from this path instead of ckpt_dir + apply_yarn: Optional[bool] = attrs.field(default=False) + yarn_scale: Optional[float] = attrs.field(default=None) + yarn_beta_fast: Optional[int] = attrs.field(default=None) + yarn_beta_slow: Optional[int] = attrs.field(default=None) + original_seq_len: Optional[int] = attrs.field(default=None) + vision_encoder: Optional[str] = attrs.field(default=None) + vision_encoder_in_channels: Optional[int] = attrs.field(default=3) + mm_projector: Optional[str] = attrs.field(default=None) + rope_dim: Optional[str] = attrs.field(default="1D") + pytorch_rope_version: Optional[str] = attrs.field(default="v2") + original_latent_shape: Optional[list] = None + pad_to_multiple_of: Optional[int] = None + vision_encoder_in_channels: Optional[int] = attrs.field(default=3) + insert_cross_attn: bool = False + insert_cross_attn_every_k_layers: int = 1 + context_dim: Optional[int] = attrs.field(default=1024) + # For video training + num_video_frames: Optional[int] = None + # Raw video pixel dimension + video_height: Optional[int] = None + video_width: Optional[int] = None + # Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact + video_latent_shape: Optional[list] = None + + def __getitem__(self, item): + return getattr(self, item) + + +@attrs.define +class TrainingModelConfig: + """ + A class to hold model configuration arguments. + + Args: + dim (int): The dimensionality of the input and output of each transformer block. + n_layers (int): Number of layers in the transformer. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to + `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention. + head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads. + vocab_size (int): Vocabulary size. + multiple_of (int): Ensures the hidden layer size is a multiple of this value for SwiGLU activation. + ffn_dim_multiplier (Optional[float]): Multiplier for feedforward network dimension. + ffn_hidden_size (Optional[int]): Hidden size for feedforward network. If None, use ffn_dim_multiplier to compute it. + norm_eps (float): Epsilon value for normalization. + rope_theta (float): Theta value for rotary positional embeddings. 
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings. + max_batch_size (int): Maximum batch size for inference (determines KV cache size). + max_seq_len (int): Maximum sequence length for input text (determines KV cache size). + fuse_qkv (bool): Whether to fuse QKV in attention. Flag for the pytorch backend. + causal_mask (bool): Whether to use causal mask. Defaults to True. + flash_attn (bool): Whether to use Flash attention. + norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm". + backend (str): Backend for the model. + precision (str): Data type for the model. + ema (config.EMAConfig): Configuration for exponential moving average. + embedding_dropout(float): Dropout rate for the embedding layer. + attention_dropout(float): Dropout rate for attention. + hidden_dropout(float): Dropout after the attention and feed-forward layers (following TransformerEngine's + implementation in its TransformerLayer class). + use_qk_normalization (bool): Whether to enable QK normalization. + inference (bool): Whether the model is used for inference. + act_ckpt_enabled (bool): Whether to enable activation checkpointing. + fsdp_enabled (bool): Whether to enable FSDP. + fsdp (LazyDict): Configuration for FSDP. + ckpt_dir (str): Checkpoint directory. + ckpt_path (str): Checkpoint path. + cache_dir (str): Cache directory. + apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension). + yarn_scale (Optional[float]): Scale factor for YaRN. + yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code) + yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code) + original_seq_len (Optional[int]): Original sequence length. + depth_init (bool): If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the + total number of transformer blocks. Defaults to `True` (following the TorchTitan implementation of Llama3). + context_parallel_size (int): Context parallel size. Defaults to 1. + tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1. + sequence_parallel (bool): Whether to use sequence parallelism. Defaults to False. + set_parallel_mode (bool): It is a boolean flag used by TransformerEngine to handle Tensor Parallelism. + Essentially, it is equivalent to `tensor_model_parallel_size > 1`. Defaults to `False`. + attention_tp (bool): Whether to use tensor parallelism for attention layers. + mm_projector (Optional[str]): Multimodal projector used for vision-language modeling. Defaults to None. + Choices: "identity", "linear", "mlp", "mlp_downsample". + video_latent_shape (Optional[list]): Shape of the video latent tensor. [T, H, W] + image_latent_shape (Optional[list]): Shape of the image latent tensor. [H, W] + num_video_frames (Optional[int]): Number of video frames. + rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D". + pytorch_rope_version (Optional[str]): Version of the RoPE for the `pytorch` backend. "v1" is the Llama implementation, and "v2" is HuggingFace/TransformerEngine implementation. + original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + peft_last_n_layers (Optional[int]): Number of last few layers to fine-tune in Parameter Efficient Fine-Tuning (PEFT). 
When this and peft_every_n_layers are both 0, it means all layers are fine-tuned (FFT). + peft_every_n_layers (Optional[int]): In Parameter Efficient Fine-Tuning (PEFT), every n layers are unfrozen and can be trained (in flamingo style). When this and peft_last_n_layers are both 0, + it means all layers are fine-tuned (FFT). For example, for a 40 layer model, n=8 means training layers 7, 15, 23, 31, 39, which includes the final layer. + It is advised to pick n such that the final layer is included. + freeze_vision_encoder (bool): Whether to freeze the vision encoder in vision-language model training. Defaults to False. + vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4. + insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer. + insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers. + context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim. + finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn). + finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn). + use_action_condition (bool): Whether to use the robot action condition. + action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp". + action_dim (Optional[int]): The dimensionality of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]). + action_embedding_dim (Optional[int]): The dimensionality of the robot action embedding. + group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal". + sync_1d_parameters (bool): Whether to synchronize layernorm parameters (1D) across tensor parallel ranks (default True). + Note: this is to ensure all TP-ranks have the same layernorm parameters. + z_loss_coeff (float): The coefficient for the z-loss. + insert_medusa_head (bool): Whether to insert the Medusa head. + ft_medusa_option (str): Options on which layers to finetune, choices like: + "fft": fully fine-tune both medusa heads and all LLM backbone; + "head": fine-tune medusa heads; + "head_out": fine-tune medusa heads, and the output layer; + "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone. + medusa_num_heads (int): Number of heads in the Medusa head. + medusa_num_layers (int): Number of layers in the Medusa head. + medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1. + zero_init_cross_attn_proj (bool): Whether to initialize the cross-attn proj layer with zeros (default False). + concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False). 
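# Tiny sketch (assumed helper, not taken from the patched files) of the layer-selection
# rule described for peft_every_n_layers above; 0-based layer indexing is an assumption.
def peft_trainable_layers(n_layers: int, peft_every_n_layers: int) -> list:
    if peft_every_n_layers == 0:
        # Together with peft_last_n_layers == 0 this means full fine-tuning (FFT).
        return list(range(n_layers))
    return [i for i in range(n_layers) if (i + 1) % peft_every_n_layers == 0]

# peft_trainable_layers(40, 8) -> [7, 15, 23, 31, 39]: every 8th layer is unfrozen and
# the final layer is included, matching the docstring's example.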
+ """ + + dim: int = attrs.field(default=4096) + n_layers: int = attrs.field(default=32) + n_heads: int = attrs.field(default=32) + n_kv_heads: Optional[int] = attrs.field(default=8) + head_dim: Optional[int] = attrs.field(default=None) + vocab_size: int = attrs.field(default=128256) + multiple_of: int = attrs.field(default=1024) # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = attrs.field(default=1.3) + ffn_hidden_size: Optional[int] = attrs.field(default=None) + norm_eps: float = attrs.field(default=1e-5) + rope_theta: float = attrs.field(default=500000) + apply_abs_pos_emb: bool = attrs.field(default=False) + max_batch_size: int = attrs.field(default=1) + max_seq_len: int = attrs.field(default=8192) + fuse_qkv: bool = attrs.field(default=False) + causal_mask: bool = attrs.field(default=True) + flash_attn: bool = attrs.field(default=True) + norm_type: str = attrs.field(default="rmsnorm") + backend: str = attrs.field(default="pytorch") + precision: str = attrs.field(default="bfloat16") + ema: config.EMAConfig = config.EMAConfig(enabled=False) + embedding_dropout: float = 0.0 + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + use_qk_normalization: bool = False + tokenizer: Optional[TokenizerConfig] = None + inference: bool = False + act_ckpt_enabled: bool = False + fsdp_enabled: bool = False + context_parallel_size: int = attrs.field(default=1) + tensor_model_parallel_size: int = attrs.field(default=1) + sequence_parallel: bool = attrs.field(default=False) + set_parallel_mode: bool = attrs.field(default=False) + fsdp: LazyDict = LazyDict( + dict( + policy="auto", # choices: ["size", "auto"] + min_num_params=1024, # Used as policy == "size" + sharding_strategy="hybrid", # Choices: ["full", "hybrid"]. "full" means sharding_group_size = world_size + sharding_group_size=8, # If None, defaults to min(world_size, 8). Recommends 8 for training on 8-GPU nodes. 
+ ) + ) + ckpt_dir: Optional[str] = attrs.field(default="") + ckpt_path: Optional[str] = attrs.field( + default=None + ) # If not None, load the model from this path instead of ckpt_dir + cache_dir: Optional[str] = attrs.field(default="/project/cosmos/ar/cache") + apply_yarn: Optional[bool] = attrs.field(default=False) + yarn_scale: Optional[float] = attrs.field(default=None) + yarn_beta_fast: Optional[int] = attrs.field(default=None) + yarn_beta_slow: Optional[int] = attrs.field(default=None) + original_seq_len: Optional[int] = attrs.field(default=None) + depth_init: bool = attrs.field(default=True) + ignore_first_num_tokens: int = 0 + z_loss_coeff: float = 1e-4 + attention_tp: bool = False + vision_encoder: Optional[str] = attrs.field(default=None) + mm_projector: Optional[str] = attrs.field(default=None) + rope_dim: Optional[str] = attrs.field(default="1D") + pytorch_rope_version: Optional[str] = attrs.field(default="v2") + original_latent_shape: Optional[list] = None + pad_to_multiple_of: Optional[int] = None + peft_last_n_layers: Optional[int] = attrs.field(default=0) + peft_every_n_layers: Optional[int] = attrs.field(default=0) + freeze_vision_encoder: bool = False + vision_encoder_in_channels: Optional[int] = attrs.field(default=3) + insert_cross_attn: bool = False + insert_cross_attn_every_k_layers: int = 1 + context_dim: Optional[int] = attrs.field(default=1024) + finetune_layers_with_cross_attn: bool = False + finetune_layers_without_cross_attn: bool = False + use_action_condition: bool = False + action_embedding_mode: Optional[str] = attrs.field(default="mlp") + action_dim: Optional[int] = attrs.field(default=_ACTION_DIM) + action_embedding_dim: Optional[int] = attrs.field(default=1024) + group_causal_mask_mode: Optional[str] = attrs.field(default=None) + sync_1d_parameters: bool = True + # hyper-parameters for the medusa head configs + insert_medusa_head: bool = False + ft_medusa_option: str = "fft" + medusa_num_heads: int = 7 + medusa_num_layers: int = 1 + medusa_concat_heads: bool = True + # For video training + num_video_frames: Optional[int] = None + # Raw video pixel dimension + video_height: Optional[int] = None + video_width: Optional[int] = None + # Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact + video_latent_shape: Optional[list] = None + # For image training + image_latent_shape: Optional[list] = None + # For robot training (action) + zero_init_cross_attn_proj: bool = False + # For robot training (action) + concat_action_to_context: bool = False + + def __getitem__(self, item): + return getattr(self, item) diff --git a/cosmos_predict1/autoregressive/configs/base/model_config.py b/cosmos_predict1/autoregressive/configs/base/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..53442dc6faea3346b59bb860014cde1373e236bd --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/model_config.py @@ -0,0 +1,718 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Callable, List, Optional + +import torch +from megatron.core import ModelParallelConfig + +from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TrainingModelConfig +from cosmos_predict1.autoregressive.configs.base.tokenizer import ( + TextTokenizerConfig, + TokenizerConfig, + VideoTokenizerConfig, + create_discrete_video_fsq_tokenizer_state_dict_config, +) +from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer +from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer +from cosmos_predict1.autoregressive.training.model import AutoRegressiveTrainingModel +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config import EMAConfig +from cosmos_predict1.utils.lazy_config import LazyCall as L + +# Common architecture specifications +BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336} +COSMOS_ARCHITECTURES = { + "1b": { + "n_layers": 16, + "dim": 2048, + "n_heads": 32, + }, + "4b": { + "n_layers": 16, + "dim": 4096, + "n_heads": 32, + }, + "12b": { + "n_layers": 40, + "dim": 5120, + "n_heads": 32, + "head_dim": 128, + }, +} + +COSMOS_YARN_CONFIG = { + "original_latent_shape": [3, 40, 64], + "apply_yarn": True, + "yarn_beta_fast": 4, + "yarn_beta_slow": 1, + "yarn_scale": 2, +} + +# Llama3 architecture specifications for different model sizes +LLAMA3_ARCHITECTURES = { + "8b": { + "n_layers": 32, + "dim": 4096, + "n_heads": 32, + "ffn_hidden_size": 14336, + }, +} +# Llama3.1 uses YaRN for long context support (context of 128k tokens) +LLAMA_YARN_CONFIG = { + "apply_yarn": True, + "yarn_scale": 8, + "yarn_beta_fast": 4, + "yarn_beta_slow": 1, +} + +# Mistral architecture specifications for different model sizes +MISTRAL_ARCHITECTURES = { + "12b": { + "n_layers": 40, + "dim": 5120, + "n_heads": 32, + "ffn_hidden_size": 14336, + "head_dim": 128, + }, +} + +PIXTRAL_VISION_ARCHITECTURES = { + "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"}, +} + + +def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict: + """ + Get the model architecture specifications for the given model size, model family and pretrained status. + + Args: + model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral" + pretrained (bool): Whether to load pretrained weights. + + Returns: + dict: A dictionary containing the model architecture specifications. 
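# Usage sketch (assumed, not taken from the patched files) for get_model_arch_specs,
# whose implementation follows. The expected result is read off the architecture tables
# above (BASE_CONFIG plus COSMOS_ARCHITECTURES).
specs = get_model_arch_specs(model_size="4b", model_family="cosmos", pretrained=True)
# -> {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336,
#     "n_layers": 16, "dim": 4096, "n_heads": 32}
# Only the 12B Cosmos variant additionally merges COSMOS_YARN_CONFIG for long-context RoPE.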
+ """ + arch_specs = copy.deepcopy(BASE_CONFIG) + model_size = model_size.lower() + if model_family.startswith("cosmos"): + arch_specs.update(COSMOS_ARCHITECTURES[model_size]) + elif model_family.startswith("llama"): + arch_specs.update(LLAMA3_ARCHITECTURES[model_size]) + elif model_family in ["mistral", "pixtral"]: + arch_specs.update(MISTRAL_ARCHITECTURES[model_size]) + if model_family == "pixtral": + arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size]) + else: + raise ValueError(f"Model family {model_family} is not supported.") + + if pretrained: + if model_family == "cosmos": + if model_size == "12b": + arch_specs.update(COSMOS_YARN_CONFIG) + log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}") + else: + pass + elif model_family in ["llama", "llama3"]: + pretrained_specs = { + "rope_theta": 500000, + "max_seq_len": 8192, + "vocab_size": 128256, + } + arch_specs.update(pretrained_specs) + elif model_family == "llama3.1": + pretrained_specs = { + "rope_theta": 500000, + "max_seq_len": 131072, + "original_seq_len": 8192, + "vocab_size": 128256, + **LLAMA_YARN_CONFIG, + } + arch_specs.update(pretrained_specs) + elif model_family == "mistral": + assert model_size == "12b", "We only support Mistral-Nemo-12B model." + pretrained_specs = { + "rope_theta": 1000000, + "max_seq_len": 128000, + "vocab_size": 131072, + } + arch_specs.update(pretrained_specs) + elif model_family == "pixtral": + assert model_size == "12b", "We only support Pixtral 12B model." + pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072} + arch_specs.update(pretrained_specs) + else: + raise ValueError(f"Model family {model_family} doesn't have a pretrained config.") + + return arch_specs + + +def create_text_model_config( + model_ckpt_path: str, + tokenizer_path: str, + tensor_model_parallel_size: int = 1, + model_family: str = "mistral", + model_size: str = "12b", + is_instruct_model: bool = True, + max_seq_len: int = None, + max_batch_size: int = 1, + rope_dim: str = "1D", + add_special_tokens: bool = True, + pytorch_rope_version: str = None, +) -> dict: + """Create a text model for training or inference. + Args: + model_ckpt_path (str): Path to the model checkpoint. + tokenizer_path (str): Path to the tokenizer folder. + tensor_model_parallel_size (int): Number of tensor model parallel groups. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc. + is_instruct_model (bool): Whether the model is an instruct model. + inference (bool): Whether to create the model for inference. + max_seq_len (int): Maximum sequence length. + max_batch_size (int): Maximum batch size. + rope_dim (str): RoPE dimension. Choices: "1D", "3D". + add_special_tokens (bool): Whether to add special tokens. + Returns: + dict: A dictionary containing the model configuration, which can be used to instantiate the model object. 
+ """ + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + if max_seq_len is not None: + # Override the max_seq_len if provided + model_arch_specs["max_seq_len"] = max_seq_len + if pytorch_rope_version is not None: + model_arch_specs["pytorch_rope_version"] = pytorch_rope_version + model_config = ModelConfig( + max_batch_size=max_batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=False, + tensor_model_parallel_size=tensor_model_parallel_size, + rope_dim=rope_dim, + **model_arch_specs, + ) + + tokenizer_config = TokenizerConfig( + text_tokenizer=TextTokenizerConfig( + config=L(TextTokenizer)( + model_family=model_family, + is_instruct_model=is_instruct_model, + local_path=tokenizer_path, + ), + data_key="text", + tokenizer_offset=model_config.vocab_size, + tokenize_here=False, + vocab_size=model_config.vocab_size, + ), + seq_len=model_config.max_seq_len, + training_type="text_only", + add_special_tokens=add_special_tokens, + ) + return model_config, tokenizer_config + + +def create_vision_language_model_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + tensor_model_parallel_size: int = 1, + model_family: str = "pixtral", + model_size: str = "12b", + is_instruct_model: bool = True, + max_batch_size: int = 1, + rope_dim: str = "1D", + add_special_tokens: bool = True, + max_seq_len: int = None, + vision_encoder_in_channels: int = 3, + fuse_qkv: bool = False, + pytorch_rope_version: str = None, +) -> dict: + """Create a vision-language model for training or inference. + Args: + model_ckpt_path (str): Path to the model checkpoint. + tokenizer_ckpt_path (str): Path to the tokenizer checkpoint. + tensor_model_parallel_size (int): Number of tensor model parallel groups. + model_family (str): Model family. Choices: "pixtral". + model_size (str): Model size. Choices: "12b". + is_instruct_model (bool): Whether the model is an instruct model. + rope_dim (str): RoPE dimension. Choices: "1D". + add_special_tokens (bool): Whether to add special tokens. + max_seq_len (int): Maximum sequence length. + vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4 channel images where last channel is binary mask, set this to 4. + fuse_qkv (bool): Whether to fuse the QKV linear layers. + Returns: + dict: A dictionary containing the model configuration, which can be used to instantiate the model object. 
+ """ + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + if max_seq_len is not None: + # Override the max_seq_len if provided + model_arch_specs["max_seq_len"] = max_seq_len + if pytorch_rope_version is not None: + model_arch_specs["pytorch_rope_version"] = pytorch_rope_version + + model_config = ModelConfig( + max_batch_size=max_batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=False, + tensor_model_parallel_size=tensor_model_parallel_size, + rope_dim=rope_dim, + vision_encoder_in_channels=vision_encoder_in_channels, + fuse_qkv=fuse_qkv, + **model_arch_specs, + ) + # Vision-language tokenizer + tokenizer_config = TokenizerConfig( + text_tokenizer=TextTokenizerConfig( + config=L(ImageTextTokenizer)( + model_family=model_family, + is_instruct_model=is_instruct_model, + image_processor_path=tokenizer_ckpt_path, + tokenizer_path=tokenizer_ckpt_path, + ), + data_key="image_text_interleaved", + tokenizer_offset=model_config.vocab_size, + tokenize_here=False, + vocab_size=model_config.vocab_size, + ), + seq_len=model_config.max_seq_len, + training_type="image_text_interleaved", + add_special_tokens=add_special_tokens, + ) + return model_config, tokenizer_config + + +def create_video2world_model_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + tensor_model_parallel_size: int = 1, + model_family: str = "cosmos", + model_size: str = "4b", + pixel_chunk_duration: int = 9, + num_video_frames: int = 36, + compression_ratio: List[int] = [8, 16, 16], + original_seq_len: int = 8192, + num_condition_latents_t: int = 1, + num_tokens_to_ignore: int = -1, + batch_size: int = 2, + video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config, + rope_dim: str = "3D", + add_special_tokens: bool = True, + video_height: int = 384, + video_width: int = 640, + use_qk_normalization: bool = True, + insert_cross_attn: bool = False, + insert_cross_attn_every_k_layers: int = 1, + context_dim: int = 1024, + training_type: str = "video_to_video", + pad_to_multiple_of: Optional[int] = 64, + vocab_size: int = 64000, + apply_abs_pos_emb: bool = False, +) -> dict: + """Create a video-to-world model config. + Args: + tensor_model_parallel_size (int): Number of tensor model parallel groups. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "8b", "3b". + pixel_chunk_duration (int): Number of frames in each chunk. + num_video_frames (int): Number of video frames. + compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8]. + original_seq_len (int): Original sequence length. + apply_yarn (bool): Whether to apply YaRN for long context scaling. + yarn_beta_fast (Optional[int]): Fast beta for YaRN. + yarn_beta_slow (Optional[int]): Slow beta for YaRN. + yarn_scale (Optional[int]): Scale factor for ctx extension. + use_qk_normalization (bool): Whether to use Query-Key normalization. + training_type (str): Type of training task. + batch_size (int): Batch size. + video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config + video_tokenizer_version (str): Version of the video tokenizer. + num_condition_latents_t (int): Number of conditioning latent channels + num_tokens_to_ignore (int) = Number of tokens to ignore. 
This takes the precedence + video_height (int): Height of the video frame. Defaults to 384. + video_width (int): Width of the video frame. Defaults to 640. + rope_dim (str): RoPE dimension. Choices: "1D", "3D". + add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE. + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. + vocab_size (int): Vocabulary size. + apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings. + Returns: + dict: A dictionary containing the model configuration representing the model object, can be instantiated. + """ + assert ( + pixel_chunk_duration % compression_ratio[0] == 1 + ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})" + latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1 + latent_height = video_height // compression_ratio[1] + latent_width = video_width // compression_ratio[2] + # Do some math to compute the video latent shape and sequence length + assert ( + num_video_frames % pixel_chunk_duration == 0 + ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}" + video_latent_shape = [ + num_video_frames // pixel_chunk_duration * latent_chunk_duration, + latent_height, + latent_width, + ] + # product of video_latent_shape + num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2] + if add_special_tokens: + seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3 + seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64 + # for text to video, we need to add token to indicate the start of the video + elif training_type == "text_to_video": + seq_len = num_token_video_latent + 1 + else: + seq_len = num_token_video_latent + + if seq_len % pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True) + + # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss + # If num_tokens_to_ignore is specified, use it. 
+ # Else compute it from num_condition_latents_t + if num_tokens_to_ignore < 0: + num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t + if not add_special_tokens and num_condition_latents_t > 0: + # If there are no special tokens (bov), do a -1 so that you can compute the loss + # from the first token of the next chunk + num_tokens_to_ignore -= 1 + + model_config = ModelConfig( + video_height=video_height, + video_width=video_width, + max_seq_len=seq_len, + max_batch_size=batch_size, + precision="bfloat16", + ckpt_path=model_ckpt_path, + use_qk_normalization=use_qk_normalization, + vocab_size=64000, + original_seq_len=original_seq_len, + tensor_model_parallel_size=tensor_model_parallel_size, + video_latent_shape=video_latent_shape, + num_video_frames=num_video_frames, + rope_dim=rope_dim, + pad_to_multiple_of=pad_to_multiple_of, + insert_cross_attn=insert_cross_attn, + insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers, + context_dim=context_dim, + apply_abs_pos_emb=apply_abs_pos_emb, + **model_arch_specs, + ) + + video_tokenizer_config = video_tokenizer_config_creator( + tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio + ) + tokenizer_config = TokenizerConfig( + text_tokenizer=None, + video_tokenizer=VideoTokenizerConfig( + config=video_tokenizer_config, + data_key="video", + tokenizer_offset=0, # Since there is no text embeddings in the model. Note this only apply when the model is trained from scratch. If we use text pretrained model, the offset will be vocab_size of text token. + tokenize_here=True, + max_seq_len=num_token_video_latent, + vocab_size=vocab_size, + ), + seq_len=seq_len, + training_type=training_type, + add_special_tokens=add_special_tokens, + pad_to_multiple_of=pad_to_multiple_of, + ) + return model_config, tokenizer_config + + +def create_video2world_model( + tensor_model_parallel_size: int = 1, + context_parallel_size: int = 1, + shard_checkpoint: bool = False, + model_family: str = "cosmos", + model_size: str = "1b", + backend: str = "pytorch", + pixel_chunk_duration: int = 9, + num_video_frames: int = 36, + compression_ratio: List[int] = [8, 16, 16], + original_seq_len: int = 8192, + apply_yarn: bool = False, + yarn_beta_fast: Optional[int] = None, + yarn_beta_slow: Optional[int] = None, + yarn_scale: Optional[int] = None, + num_condition_latents_t: int = 1, + num_tokens_to_ignore: int = -1, + batch_size: int = 1, + fsdp_enabled: bool = False, + act_ckpt_enabled: bool = False, + video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config, + rope_dim: str = "3D", + add_special_tokens: bool = False, + video_height: int = 384, + video_width: int = 640, + original_latent_shape: Optional[List[int]] = None, + use_qk_normalization: bool = True, + sequence_parallel: bool = False, + insert_cross_attn: bool = False, + insert_cross_attn_every_k_layers: int = 1, + context_dim: int = 1024, + finetune_layers_with_cross_attn: bool = False, + finetune_layers_without_cross_attn: bool = False, + use_action_condition: bool = False, + action_embedding_mode: Optional[str] = "mlp", + action_dim: int = 8, # ACTION_DIM, + action_embedding_dim: int = 1024, + group_causal_mask_mode: Optional[str] = None, + training_type: str = "video_to_video", + pad_to_multiple_of: Optional[int] = 1, + z_loss_coeff: float = 1e-4, + temporal_overlap: int = 0, + embedding_dropout: float = 0.0, + insert_medusa_head: bool = False, + ft_medusa_option: str = "fft", + medusa_num_heads: int = 7, + medusa_num_layers: int = 1, + 
medusa_concat_heads: bool = True, + fuse_qkv: bool = False, + zero_init_cross_attn_proj: bool = False, + concat_action_to_context: bool = False, + tokenizer_ckpt_path: str = "checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit", +) -> dict: + """Create a video-to-video model for training. + Args: + tensor_model_parallel_size (int): Number of tensor model parallel groups. + context_parallel_size (int): Number of context parallel groups. + model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral". + model_size (str): Model size. Choices: "1b", "8b", "3b". + backend (str): Backend for the model. Choices: "pytorch", "transformer_engine". + pixel_chunk_duration (int): Number of frames in each chunk. + num_video_frames (int): Number of video frames. + compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8]. + original_seq_len (int): Original sequence length. + apply_yarn (bool): Whether to apply YaRN for long context scaling. + yarn_beta_fast (Optional[int]): Fast beta for YaRN. + yarn_beta_slow (Optional[int]): Slow beta for YaRN. + yarn_scale (Optional[int]): Scale factor for ctx extension. + fsdp_enabled (bool): Whether Fully Sharded Data Parallel (FSDP) is enabled. + act_ckpt_enabled (bool): Whether activation checkpointing is enabled. + use_qk_normalization (bool): Whether to use Query-Key normalization. + training_type (str): Type of training task. + batch_size (int): Batch size. + video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config + video_tokenizer_version (str): Version of the video tokenizer. + num_condition_latents_t (int): Number of conditioning latent channels + num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence + video_height (int): Height of the video frame. Defaults to 384. + video_width (int): Width of the video frame. Defaults to 640. + rope_dim (str): RoPE dimension. Choices: "1D", "2D", "3D". + add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE. + original_latent_shape (list): Original latent shape before RoPE scaling. + sequence_parallel (bool): Whether to enable sequence parallelism. + insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer. + insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers. + context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim. + finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn). + finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn). + use_action_condition (bool): Whether to use action condition. + action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp". + action_dim (int): Dimension of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]). + action_embedding_dim (int): Dimension of the action embedding. + group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal". + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. + z_loss_coeff (float): Coefficient for the z loss. + temporal_overlap (int): Temporal overlap in the latent space. 
+ embedding_dropout (float): Dropout rate for the embeddings. + insert_medusa_head (bool): Whether to insert the Medusa head. + ft_medusa_option (str): Options on which layers to finetune, choices like: + "fft": fully fine-tune both medusa heads and all LLM backbone; + "head": fine-tune medusa heads; + "head_out": fine-tune medusa heads, and the output layer; + "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone. + medusa_num_heads (int): Number of heads in the Medusa head. + medusa_num_layers (int): Number of layers in the Medusa head. + medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1. + fuse_qkv (bool): Whether to fuse the QKV linear layers. + zero_init_cross_attn_proj (bool): Whether to zero-initialize the cross-attention projection weights (default False). + concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False). + Returns: + dict: A dictionary containing the model configuration representing the model object, can be instantiated. + """ + assert ( + pixel_chunk_duration % compression_ratio[0] == 1 + ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})" + latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1 + latent_height = video_height // compression_ratio[1] + latent_width = video_width // compression_ratio[2] + # Compute the video latent shape and sequence length + if temporal_overlap == 0: + assert ( + num_video_frames % pixel_chunk_duration == 0 + ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}" + video_latent_shape = [ + num_video_frames // pixel_chunk_duration * latent_chunk_duration, + latent_height, + latent_width, + ] + + else: + # Calculate temporal overlap in the latent space + temporal_overlap_latent = temporal_overlap // compression_ratio[0] + + # Calculate the effective number of latent chunks for the video + latent_chunks = (num_video_frames - temporal_overlap) // (pixel_chunk_duration - temporal_overlap) + + # Compute the total duration of the latent chunks, accounting for overlap + effective_latent_duration = ( + latent_chunk_duration - temporal_overlap_latent + ) * latent_chunks + temporal_overlap_latent + + # Define the shape of the video in the latent space + video_latent_shape = [ + effective_latent_duration, # Temporal dimension + latent_height, # Height in the latent space + latent_width, # Width in the latent space + ] + + # product of video_latent_shape + num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2] + if add_special_tokens: + seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3 + seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64 + # for text to video, we need to add token to indicate the start of the video + elif training_type == "text_to_video": + seq_len = num_token_video_latent + 1 + else: + seq_len = num_token_video_latent + + if seq_len % pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + + # Model size specific parameters + model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=False) + + inference = False # False for training, True for inference + # set_parallel_mode = 
True + set_parallel_mode = tensor_model_parallel_size > 1 + attention_tp = True + + if context_parallel_size > 1: + assert backend == "transformer_engine", "Context parallelism is only supported in transformer engine." + + if tensor_model_parallel_size > 1: + assert set_parallel_mode, "Tensor model parallelism is only supported in parallel mode." + + # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss + # If num_tokens_to_ignore is specified, use it. + # Else compute it from num_condition_latents_t + if num_tokens_to_ignore < 0: + num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t + if not add_special_tokens and num_condition_latents_t > 0: + # If there are no special tokens (bov), do a -1 so that you can compute the loss + # from the first token of the next chunk + num_tokens_to_ignore -= 1 + + model_config = TrainingModelConfig( + video_height=video_height, + video_width=video_width, + max_seq_len=seq_len, + max_batch_size=batch_size, + inference=inference, + backend=backend, + precision="bfloat16", + ema=EMAConfig(enabled=False), + act_ckpt_enabled=act_ckpt_enabled, + fsdp_enabled=fsdp_enabled, + cache_dir=None, + ckpt_path="checkpoints/Cosmos-Predict1-4B/model.pt", + use_qk_normalization=use_qk_normalization, + vocab_size=64000, + ignore_first_num_tokens=num_tokens_to_ignore, + apply_yarn=apply_yarn, + yarn_beta_fast=yarn_beta_fast, + yarn_beta_slow=yarn_beta_slow, + original_seq_len=original_seq_len, + yarn_scale=yarn_scale, + context_parallel_size=context_parallel_size, + tensor_model_parallel_size=tensor_model_parallel_size, + set_parallel_mode=set_parallel_mode, + attention_tp=attention_tp, + video_latent_shape=video_latent_shape, + num_video_frames=num_video_frames, + rope_dim=rope_dim, + original_latent_shape=original_latent_shape, + pad_to_multiple_of=pad_to_multiple_of, + sequence_parallel=sequence_parallel, + insert_cross_attn=insert_cross_attn, + insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers, + context_dim=context_dim, + finetune_layers_with_cross_attn=finetune_layers_with_cross_attn, + finetune_layers_without_cross_attn=finetune_layers_without_cross_attn, + use_action_condition=use_action_condition, + action_embedding_mode=action_embedding_mode, + action_dim=action_dim, + action_embedding_dim=action_embedding_dim, + group_causal_mask_mode=group_causal_mask_mode, + z_loss_coeff=z_loss_coeff, + embedding_dropout=embedding_dropout, + insert_medusa_head=insert_medusa_head, + ft_medusa_option=ft_medusa_option, + medusa_num_heads=medusa_num_heads, + medusa_num_layers=medusa_num_layers, + medusa_concat_heads=medusa_concat_heads, + fuse_qkv=fuse_qkv, + zero_init_cross_attn_proj=zero_init_cross_attn_proj, + concat_action_to_context=concat_action_to_context, + **model_arch_specs, + ) + + tokenizer_config = TokenizerConfig( + text_tokenizer=None, + video_tokenizer=VideoTokenizerConfig( + config=video_tokenizer_config_creator( + ckpt_path=tokenizer_ckpt_path, pixel_chunk_duration=pixel_chunk_duration + ), + data_key="video", + tokenizer_offset=0, + vocab_size=64000, + tokenize_here=True, + max_seq_len=num_token_video_latent, + temporal_overlap=temporal_overlap, + ), + seq_len="${model.model_config.max_seq_len}", + training_type=training_type, + add_special_tokens=add_special_tokens, + pad_to_multiple_of=pad_to_multiple_of, + ) + + model_parallel = ModelParallelConfig( + bf16=True, + params_dtype=getattr(torch, "bfloat16"), + ) + model_parallel.tensor_model_parallel_size = 
"${model.model_config.tensor_model_parallel_size}" + model_parallel.context_parallel_size = "${model.model_config.context_parallel_size}" + model_parallel.sequence_parallel = "${model.model_config.sequence_parallel}" + return L(AutoRegressiveTrainingModel.build)( + seed=0, + train_from_scratch=True, + model_config=model_config, + fsdp_checkpointer=None, + tokenizer_config=tokenizer_config, + model_parallel=model_parallel, + shard_checkpoint=shard_checkpoint, + ) diff --git a/cosmos_predict1/autoregressive/configs/base/model_parallel.py b/cosmos_predict1/autoregressive/configs/base/model_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..5f93e257bce3af5fff5e79317d07a3199e7650fd --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/model_parallel.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import ModelParallelConfig + +from cosmos_predict1.utils.lazy_config import LazyDict + + +def create_model_parallel_config(): + model_parallel = ModelParallelConfig(bf16=True, params_dtype=getattr(torch, "bfloat16")) + model_parallel.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}" + model_parallel.context_parallel_size = "${model.model_parallel.context_parallel_size}" + model_parallel.sequence_parallel = "${model.model_parallel.sequence_parallel}" + MODEL_PARALLELS = LazyDict( + dict( + model_parallel_bf16=model_parallel, + ), + flags={"allow_objects": True}, + ) + return MODEL_PARALLELS["model_parallel_bf16"] diff --git a/cosmos_predict1/autoregressive/configs/base/optim.py b/cosmos_predict1/autoregressive/configs/base/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..beed4c4959f86d9ce440c90899bd5fd8c8b32cbd --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/optim.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from cosmos_predict1.utils.lazy_config import LazyCall as L + + +class LambdaLinearWarmupScheduler: + """ + A learning rate scheduler that implements linear warm-up and cool-down. + + This scheduler provides three phases: + 1. Warm-up: Learning rate linearly increases from 0 to 1. + 2. Constant: Learning rate remains at 1. 
+ 3. Cool-down: Learning rate linearly decreases from 1 to 0. + + Args: + warmup_steps (int): Number of steps for the warm-up phase. + warmup_offset (int): Starts warmup from this offset. + max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided. + cooldown_steps (int, optional): Number of steps for the cool-down phase. + + Raises: + ValueError: If cooldown_steps is provided without max_iter, or if an invalid step is given. + """ + + def __init__(self, warmup_steps: int, warmup_offset: int = 0, max_iter: int = None, cooldown_steps: int = None): + self.warmup_steps = warmup_steps + self.warmup_offset = warmup_offset + self.max_iter = max_iter + self.cooldown_steps = cooldown_steps + + if cooldown_steps is not None: + if max_iter is None: + raise ValueError("max_iter must be specified when cooldown_steps is provided") + self.cooldown_start = max_iter - cooldown_steps + else: + self.cooldown_start = None + + def __call__(self, step): + # Warm-up phase + if step < self.warmup_offset: + return 0 + + if step < self.warmup_steps + self.warmup_offset: + return float(step - self.warmup_offset) / float(max(1, self.warmup_steps)) + + # Constant phase (no cool-down) + elif self.cooldown_steps is None: + return 1.0 + + # Constant phase (before cool-down starts) + elif step < self.cooldown_start: + return 1.0 + + # Cool-down phase + elif self.cooldown_start <= step < self.max_iter: + cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps + return 1.0 - cooldown_progress + + # After max_iter + elif step >= self.max_iter: + return 0.0 + + # Unexpected case + else: + raise ValueError(f"Invalid step {step}") + + +LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)( + optimizer=None, + lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000), +) diff --git a/cosmos_predict1/autoregressive/configs/base/tokenizer.py b/cosmos_predict1/autoregressive/configs/base/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0e9e81eaccfe86d86411a2e0ef194a4d40a0460b --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/base/tokenizer.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import attrs + +from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer +from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def create_discrete_video_fsq_tokenizer_state_dict_config( + ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16] +) -> LazyDict: + CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)( + # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime. 
+ # - It relies on fully 3D discrete wavelet transform + # - Uses a layer norm instead of a group norm + # - Factorizes full convolutions into spatial and temporal convolutions + # - Factorizes full attention into spatial and temporal attention + # - Strictly causal, with flexible temporal length at inference. + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + name="CausalDiscreteFactorizedVideoTokenizer", + ) + + return L(DiscreteVideoFSQStateDictTokenizer)( + enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"), + dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"), + tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig, + name="discrete_video_fsq", + latent_ch=6, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0], + max_enc_batch_size=8, + max_dec_batch_size=4, + levels=[8, 8, 8, 5, 5, 5], + compression_ratio=compression_ratio, + ) + + +@attrs.define(slots=False) +class TextTokenizerConfig: + """ + Text tokenizer config + + Args: + config: Config file to define the text tokenizer class. + data_key (str): The input key from data_dict that will be passed to the text tokenizer. + tokenize_here (bool): Whether to use the tokenizer to perform online tokenization. + tokenizer_offset (int): Offset that is added to the tokens. + vocab_size (int): Vocabulary size of the tokenizer. + """ + + config: LazyDict + data_key: str = "" + tokenize_here: bool = False + tokenizer_offset: int = 0 + vocab_size: int = 0 + + +@attrs.define(slots=False) +class VideoTokenizerConfig: + """ + Video tokenizer config + + Args: + config: Config file to define the video tokenizer class. + data_key (str): The input key from data_dict that will be passed to the video tokenizer. + tokenize_here (bool): Whether to use the tokenizer to perform online tokenization. + tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we + add an offset to make sure that video tokens and text tokens don't overlap. + vocab_size (int): Vocabulary size of the tokenizer. + max_seq_len (int): Maximum token length for an input video. + temporal_overlap (int): Overlap between consecutive video chunks. + """ + + config: LazyDict + data_key: str = "" + tokenize_here: bool = True + tokenizer_offset: int = 0 + vocab_size: int = 0 + max_seq_len: int = -1 + temporal_overlap: int = 0 + + +@attrs.define(slots=False) +class TokenizerConfig: + """ + Joint tokenizer config + + Args: + text_tokenizer (TextTokenizerConfig): Text tokenizer config file + class_tokenizer (ClassTokenizerConfig): Class tokenizer config file + video_tokenizer (VideoTokenizerConfig): Video tokenizer config file + image_tokenizer (ImageTokenizerConfig): Image tokenizer config file + seq_len (int): Final token sequence length + training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"] + add_special_tokens (bool): Whether to add special tokens to the output tokens + pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64. 
+ """ + + text_tokenizer: Optional[TextTokenizerConfig] = None + video_tokenizer: Optional[VideoTokenizerConfig] = None + seq_len: int = 4096 + training_type: str = None + add_special_tokens: bool = True + pad_to_multiple_of: Optional[int] = 64 diff --git a/cosmos_predict1/autoregressive/configs/config.py b/cosmos_predict1/autoregressive/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..df074434b8128b849e6570d8579e32f121e88ca5 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/config.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default config for cosmos_ar project.""" + +import os +from typing import Any, List + +import attrs + +from cosmos_predict1.autoregressive.configs.registry import register_configs +from cosmos_predict1.autoregressive.trainer import Trainer +from cosmos_predict1.utils import config, log +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"model": None}, + {"data_train": "mock_video"}, + {"data_val": None}, + {"optimizer": "fused_adamw"}, + {"scheduler": "warmup_cosine_lr"}, + {"checkpoint": "local"}, + {"callbacks": "basic"}, + {"global_config": None}, + {"experiment": None}, + ] + ) + + def validate(self) -> None: + """Validate that the config has all required fields.""" + assert self.job.project != "", "job.project is not set" + assert self.job.group != "", "job.group is not set" + assert self.job.name != "", "job.name is not set" + log.info("Validating config for cosmos_autoregressive job") + # FSDP config check + if self.model.model_config.fsdp_enabled: + assert self.trainer.distributed_parallelism == "fsdp" + else: + assert self.trainer.distributed_parallelism == "ddp" + + # Transformer Engine config check + if self.model.model_config.backend == "transformer_engine": + assert ( + "NVTE_FLASH_ATTN" in os.environ and os.environ["NVTE_FLASH_ATTN"] == "1" + ) # Enable Flash attention for transformer engine + + # TP, CP config check + if self.model_parallel is not None: + if self.model_parallel.context_parallel_size > 1: + assert ( + self.model.model_config.backend == "transformer_engine" + ), "Context parallelism is only supported in transformer engine." + + if self.model_parallel.tensor_model_parallel_size > 1: + assert ( + self.model.model_config.set_parallel_mode + ), "Tensor model parallelism is only supported in parallel mode." + + if self.model_parallel.sequence_parallel: + assert ( + self.model_parallel.tensor_model_parallel_size > 1 + ), "Sequence parallelism is only supported in tensor model parallelism." + assert ( + self.model.model_config.backend == "transformer_engine" + ), "Sequence parallelism is only supported in transformer engine." 
+ + +def make_config(): + c = Config( + model=None, + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + checkpoint=None, + ) + + c.job.project = "cosmos_autoregressive" + c.job.group = "debug" + c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = Trainer + c.trainer.run_validation = True + + c.trainer.seed = 0 + c.trainer.max_iter = 10 + c.trainer.logging_iter = 1 + + c.trainer.callbacks = None + register_configs() + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.autoregressive.configs.experiment") + return c diff --git a/cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py b/cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py b/cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..7c427968e19669f9e51a97650f38037feb03efd6 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + This file contains a basic configuration for video2video experiments. +""" + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model +from cosmos_predict1.autoregressive.configs.base.model_parallel import create_model_parallel_config +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyDict + +cs = ConfigStore.instance() + + +""" + Finetune 4B model with TP=1, pytorch backend, low resolution tealrobot data, frames 33, chunk 33. 
+ Usage: + torchrun --nproc_per_node=1 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobotsmall_tp1 +""" +base_4b_example_tealrobotsmall_tp1: LazyDict = LazyDict( + dict( + defaults=[ + {"override /data_train": "tealrobot_video_small"}, + { + "override /callbacks": [ + "basic", + "video_teacher_forcing", + ] + }, + {"override /checkpoint": "local"}, + {"override /optimizer": "fused_adamw"}, + {"override /scheduler": "warmup_cosine_lr"}, + "_self_", + ], + job=dict( + project="posttraining", + group="autoregressive_base", + name="base_4b_example_tealrobotsmall_tp1", + ), + model=create_video2world_model( + model_size="4b", + model_family="cosmos", + backend="pytorch", + tensor_model_parallel_size=1, + batch_size=1, + pixel_chunk_duration=33, + num_video_frames=33, + video_height=384, + video_width=640, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + add_special_tokens=False, + ), + trainer=dict( + max_iter=50000, + grad_accum_iter=1, + grad_scaler_args=dict(enabled=False), + run_validation=False, # No need for validation as epoch <= 1 + distributed_parallelism="ddp", + callbacks=dict( + vid_sampling_tf=dict( + every_n=500, + ), + ), + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Predict1-4B/model.pt", + load_training_state=False, + strict_resume=True, + save_iter=1000, + ), + model_parallel=create_model_parallel_config(), + ), +) + + +""" + Finetune 4B model with TP=4, pytorch backend, high resolution tealrobot data, frame 33, chunk 33. + Usage: + torchrun --nproc_per_node=4 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobot_tp4 +""" +base_4b_example_tealrobot_tp4: LazyDict = LazyDict( + dict( + defaults=[ + {"override /data_train": "tealrobot_video"}, + { + "override /callbacks": [ + "basic", + "video_teacher_forcing", + ] + }, + {"override /checkpoint": "local"}, + {"override /optimizer": "fused_adamw"}, + {"override /scheduler": "warmup_cosine_lr"}, + "_self_", + ], + job=dict( + project="posttraining", + group="autoregressive_base", + name="base_4b_example_tealrobot_tp4", + ), + model=create_video2world_model( + model_size="4b", + model_family="cosmos", + backend="pytorch", + tensor_model_parallel_size=4, + batch_size=1, + pixel_chunk_duration=33, + num_video_frames=33, + video_height=640, + video_width=848, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + add_special_tokens=False, + ), + trainer=dict( + max_iter=50000, + grad_accum_iter=1, + grad_scaler_args=dict(enabled=False), + run_validation=False, # No need for validation as epoch <= 1 + distributed_parallelism="ddp", + callbacks=dict( + vid_sampling_tf=dict( + every_n=500, + ), + ), + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Predict1-4B/model.pt", + load_training_state=False, + strict_resume=False, + save_iter=1000, + ), + model_parallel=create_model_parallel_config(), + ), +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + base_4b_example_tealrobotsmall_tp1, + base_4b_example_tealrobot_tp4, + ]: + cs.store( + group="experiment", + package="_global_", + name=_item["job"]["name"], + node=_item, + ) diff --git a/cosmos_predict1/autoregressive/configs/inference/inference_config.py b/cosmos_predict1/autoregressive/configs/inference/inference_config.py new file mode 100644 index 
0000000000000000000000000000000000000000..b13ffc382b3fe20d237aa4241411cfac5444c353 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/inference/inference_config.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Union + +import attrs + +from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TokenizerConfig + + +@attrs.define(slots=False) +class DataShapeConfig: + latent_shape: list = [] + num_video_frames: Union[None, int] = None + height: Union[None, int] = None + width: Union[None, int] = None + + +@attrs.define(slots=False) +class SamplingConfig: + """ + Sampling config + Args: + temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + """ + + temperature: float = 0.6 + top_k: int = None + top_p: float = 0.9 + compile_prefill: bool = False + compile_sampling: bool = True + logprobs: bool = False + echo: bool = False + + +@attrs.define(slots=False) +class DiffusionDecoderSamplingConfig: + """ + Diffusion decoder sampling config + Args: + guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 0.8. + sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02. + sigma (float): Initial noise level for the diffusion process. Defaults to 8. + num_steps (int): Number of denoising steps to perform. Defaults to 35. + overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2. + continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16. + continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8. + dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57. 
+ """ + + guidance: float = 1.8 + sigma_min: float = 0.02 + sigma: float = 8 + num_steps: int = 15 + overlap: int = 2 + continuous_tokenizer_channel = 16 + continuous_tokenizer_spatial_compression_ratio = 8 + dd_train_num_video_frames: int = 57 + max_iter: int = 99 + fps: int = 24 + + +@attrs.define(slots=False) +class InferenceConfig: + """ + Inference config + Args: + model_config (ModelConfig): Model config + tokenizer_config (TokenizerConfig): Tokenizer config + ckpt_path (str): Path to the checkpoint + latent_shape (list): Shape of the latent + """ + + model_config: ModelConfig = None + tokenizer_config: TokenizerConfig = None + ckpt_path: str = "" + data_shape_config: DataShapeConfig = None + + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_val": None}, + {"data_shape_config": "video_shape_as_model_config"}, + {"eval_job": None}, + ] + ) diff --git a/cosmos_predict1/autoregressive/configs/registry.py b/cosmos_predict1/autoregressive/configs/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2cdfdcdddc02080d18b1b6d55ac482aac43915 --- /dev/null +++ b/cosmos_predict1/autoregressive/configs/registry.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.configs.base.callbacks import BASIC_CALLBACKS, VIDEO_TEACHER_FORCING_CALLBACK +from cosmos_predict1.autoregressive.configs.base.dataloader import get_tealrobot_video +from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearLR +from cosmos_predict1.autoregressive.configs.experiment.video2video.basic import register_experiments +from cosmos_predict1.utils import config, log +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.scheduler import WarmupCosineLR + + +def register_checkpoint(cs): + checkpoint_local = config.CheckpointConfig( + save_iter=5000, + broadcast_via_filesystem=True, + ) + cs.store(group="checkpoint", package="checkpoint", name="local", node=checkpoint_local) + + +def register_callbacks(cs): + cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) + cs.store( + group="callbacks", + package="trainer.callbacks", + name="video_teacher_forcing", + node=VIDEO_TEACHER_FORCING_CALLBACK, + ) + + +def register_scheduler(cs): + cs.store( + group="scheduler", + package="scheduler", + name="warmup_cosine_lr", + node=L(WarmupCosineLR)(optimizer=None, warmup_iters=5000, lr_decay_iters="${trainer.max_iter}", min_lr=1e-8), + ) + cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearLR) + + +def register_optimizer(cs): + cs.store( + group="optimizer", + package="optimizer", + name="fused_adamw", + node=L(torch.optim.AdamW)(params=None, lr=1e-3, weight_decay=0.05, fused=True), + ) + cs.store( + group="optimizer", + package="optimizer", + name="sgd", + node=L(torch.optim.SGD)(params=None, lr=5e-6, momentum=0.9), + ) + + +def register_training_data(cs): + cs.store( + group="data_train", + package="dataloader_train", + name="tealrobot_video_small", + node=get_tealrobot_video(num_frames=33, video_size=[384, 640]), + ) + cs.store(group="data_train", package="dataloader_train", name="tealrobot_video", node=get_tealrobot_video()) + + +def register_configs(): + log.info("Registering configs for autoregressive_base") + cs = ConfigStore.instance() + register_callbacks(cs) + register_checkpoint(cs) + register_optimizer(cs) + register_scheduler(cs) + register_training_data(cs) + register_experiments(cs) diff --git a/cosmos_predict1/autoregressive/datasets/dataset_utils.py b/cosmos_predict1/autoregressive/datasets/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e13360c351f6bc614b346e9c64420af39c088a4 --- /dev/null +++ b/cosmos_predict1/autoregressive/datasets/dataset_utils.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
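New choices for any of the groups registered above follow the same cs.store pattern; a minimal sketch, assuming one wanted to expose a plain non-fused AdamW variant next to the existing fused_adamw and sgd entries (the name adamw_unfused and its hyperparameters are hypothetical):

import torch
from hydra.core.config_store import ConfigStore

from cosmos_predict1.utils.lazy_config import LazyCall as L

cs = ConfigStore.instance()
cs.store(
    group="optimizer",
    package="optimizer",
    name="adamw_unfused",  # hypothetical entry, would be selected as optimizer=adamw_unfused
    node=L(torch.optim.AdamW)(params=None, lr=1e-4, weight_decay=0.05, fused=False),
)

Once stored, such an entry should be selectable the same way the built-in ones are, either in an experiment's defaults list or as an override next to experiment=... on the command line.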
+ +from typing import Any, Optional + +import torch +import torchvision.transforms.functional as transforms_F +from PIL import Image + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +class Augmentor: + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + r"""Base augmentor class + + Args: + input_keys (list): List of input keys + output_keys (list): List of output keys + args (dict): Arguments associated with the augmentation + """ + self.input_keys = input_keys + self.output_keys = output_keys + self.args = args + + def __call__(self, *args: Any, **kwds: Any) -> Any: + raise ValueError("Augmentor not implemented") + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_w, img_h = self.args["img_w"], self.args["img_h"] + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] >= img_h and target_size[1] >= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {(img_w, img_h)} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. 
+ """ + assert ( + (self.args is not None) and ("img_w" in self.args) and ("img_h" in self.args) + ), "Please specify size in args" + + img_w, img_h = self.args["img_w"], self.args["img_h"] + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [img_h, img_w]) + + # We also add the aug params we use. This will be useful for other transforms + crop_x0 = (orig_w - img_w) // 2 + crop_y0 = (orig_h - img_h) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": img_w, + "crop_h": img_h, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"])) + return data_dict + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict diff --git a/cosmos_predict1/autoregressive/datasets/video_dataset.py b/cosmos_predict1/autoregressive/datasets/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e129ed4b4a0d1cbae5297a31ba4286a6ae259b8a --- /dev/null +++ b/cosmos_predict1/autoregressive/datasets/video_dataset.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. 
python cosmos_predict1/autoregressive/datasets/video_dataset.py +""" + +import os +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from tqdm import tqdm + +from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig +from cosmos_predict1.autoregressive.datasets.dataset_utils import ( + CenterCrop, + Normalize, + ResizeSmallestSideAspectPreserving, +) + + +class VideoDataset(Dataset): + def __init__(self, config: VideoDatasetConfig): + """Video Dataset class for loading video-to-video generation data.""" + + super().__init__() + self.dataset_dir = config.dataset_dir + self.sequence_interval = config.sequence_interval + self.sequence_length = config.num_frames + self.video_size = config.video_size + self.start_frame_interval = config.start_frame_interval + + self.video_dir = self.dataset_dir + self.video_paths = [os.path.join(self.video_dir, f) for f in os.listdir(self.video_dir) if f.endswith(".mp4")] + print(f"{len(self.video_paths)} videos in total") + + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + + self.resize_transform = ResizeSmallestSideAspectPreserving( + input_keys=["video"], + args={"img_w": self.video_size[1], "img_h": self.video_size[0]}, + ) + self.crop_transform = CenterCrop( + input_keys=["video"], + args={"img_w": self.video_size[1], "img_h": self.video_size[0]}, + ) + self.normalize_transform = Normalize( + input_keys=["video"], + args={"mean": 0.5, "std": 0.5}, + ) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["orig_num_frames"] = n_frames + sample["chunk_index"] = -1 + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + sample["chunk_index"] += 1 + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all(), "Some frame_ids are out of range." + assert (np.array(frame_ids) >= 0).all(), "Some frame_ids are negative." 
+ vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + fps = vr.get_avg_fps() + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames) + frames = frames.permute(0, 3, 1, 2) # Rearrange from [T, H, W, C] to [T, C, H, W] + return frames, fps + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + video, fps = self._get_frames(video_path, frame_ids) + data["video"] = video + data["fps"] = fps + data["num_frames"] = self.sequence_length + data["orig_num_frames"] = sample["orig_num_frames"] + data["chunk_index"] = sample["chunk_index"] + data["frame_start"] = frame_ids[0] + data["frame_end"] = frame_ids[-1] + + data["video_name"] = { + "video_path": video_path, + "start_frame_id": str(frame_ids[0]), + } + + # resize video to smallest side aspect preserving + data = self.resize_transform(data) + # center crop video + data = self.crop_transform(data) + # normalize video + data = self.normalize_transform(data) + + data["video"] = data["video"].permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + config = VideoDatasetConfig(dataset_dir="datasets/cosmos_nemo_assets/videos/") + dataset = VideoDataset(config) + + indices = [0, 1, 2, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data.keys()=}\n" + "---" + ) + ) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/__init__.py b/cosmos_predict1/autoregressive/diffusion_decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..1d2e21c745a21193949db41c45b5654c24156c4c --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Optional + +import torch + +from cosmos_predict1.diffusion.conditioner import BaseVideoCondition, GeneralConditioner +from cosmos_predict1.diffusion.config.base.conditioner import ( + FPSConfig, + ImageSizeConfig, + LatentConditionConfig, + LatentConditionSigmaConfig, + NumFramesConfig, + PaddingMaskConfig, + TextConfig, +) +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. + latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + +class VideoDiffusionDecoderConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoLatentDiffusionDecoderCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoLatentDiffusionDecoderCondition(**output) + + +VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + latent_condition=LatentConditionConfig(), + latent_condition_sigma=LatentConditionSigmaConfig(), +) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c675081114d8b0dcc5d27f42321c98ed8decd78f --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
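For orientation, the conditioner config defined above is consumed the same way as in LatentDiffusionDecoderModel._get_conditions later in this diff: one data batch yields a conditioned and an unconditioned variant, with the unconditioned branch zero-padding fields such as latent_condition for classifier-free guidance. A minimal sketch under that assumption (the helper name and the data_batch contents are illustrative, not part of this PR):

from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate


def build_decoder_conditions(data_batch: dict):
    """Illustrative helper: build the CFG condition pair from the config above."""
    conditioner = lazy_instantiate(VideoLatentDiffusionDecoderConditionerConfig)
    # The unconditioned branch drops/zeros latent_condition, as noted in the dataclass comment.
    condition, uncondition = conditioner.get_condition_uncondition(data_batch)
    return condition.to_dict(), uncondition.to_dict()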
+ +from typing import Any, List + +import attrs + +from cosmos_predict1.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs +from cosmos_predict1.diffusion.config.base.model import LatentDiffusionDecoderModelConfig +from cosmos_predict1.diffusion.config.registry import register_configs +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"net": None}, + {"conditioner": "basic"}, + {"tokenizer": "tokenizer"}, + {"tokenizer_corruptor": None}, + {"latent_corruptor": None}, + {"pixel_corruptor": None}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config(model=LatentDiffusionDecoderModelConfig()) + + # Specifying values through instances of attrs + c.job.project = "cosmos_video4" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + # Call this function to register config groups for advanced overriding. + register_configs() + register_dd_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True) + import_all_modules_from_package("cosmos_predict1.autoregressive.diffusion_decoder.config.inference", reload=True) + return c diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py new file mode 100644 index 0000000000000000000000000000000000000000..308232872f98f300922374eb030b999866d5b3dc --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
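Roughly, each {"<group>": None} entry in the defaults list above is a Hydra config group that register_configs() / register_dd_configs() populate, and an experiment node then overrides the whole tree at once. A hedged usage sketch (the actual inference driver lives elsewhere in the repo and is not shown in this diff):

c = make_config()
# An experiment such as the diffusion-decoder node registered in
# cosmos_diffusiondecoder_7b.py below is then selected with an override like
#   experiment=DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token
print(c.job.project, c.job.group, c.job.name)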
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.diffusion_decoder.network import DiffusionDecoderGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +num_frames = 57 +Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /tokenizer": "cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624"}, + {"override /conditioner": "video_latent_diffusion_decoder_cond"}, + {"override /tokenizer_corruptor": "cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224"}, + "_self_", + ], + job=dict( + group="diffusion_deocder_FT_7Bv1_001", + name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token", + ), + model=dict( + diffusion_decoder_cond_sigma_low=0.0, + diffusion_decoder_cond_sigma_high=0.0, + diffusion_decoder_corrupt_prob=0.0, + condition_on_tokenizer_corruptor_token=True, + latent_shape=[ + 16, + num_frames, + 88, + 160, + ], + tokenizer_corruptor=dict( + pixel_chunk_duration=num_frames, + latent_chunk_duration=1 + (num_frames - 1) // 8, + ), + net=L(DiffusionDecoderGeneralDIT)( + diffusion_decoder_condition_on_sigma=False, + max_img_h=240, + max_img_w=240, + rope_h_extrapolation_ratio=1.5, + rope_w_extrapolation_ratio=1.5, + rope_t_extrapolation_ratio=1, + block_x_format="THWBD", + is_diffusion_decoder=True, + patch_spatial=2, + diffusion_decoder_condition_on_token=True, + diffusion_decoder_token_condition_voc_size=64000, + diffusion_decoder_token_condition_dim=32, + ), + tokenizer=dict( + video_vae=dict( + pixel_chunk_duration=num_frames, + ) + ), + conditioner=dict( + latent_condition=dict( + dropout_rate=0.2, + ) + ), + ), + ) +) + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY["job"]["name"], + node=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY, +) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py b/cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..fbcf6e3310394eb25d352fe453d23e0a4dcc2bdc --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
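The chunk arithmetic used in the experiment above (and again in the tokenizer registry that follows) assumes 8x temporal compression, so a pixel chunk of N frames maps to 1 + (N - 1) // 8 latent frames. A quick sanity check:

num_frames = 57
assert 1 + (num_frames - 1) // 8 == 8   # 57 pixel frames -> 8 latent frames
assert 1 + (49 - 1) // 8 == 7           # the 49-frame discrete corruptor chunks -> 7 latent frames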
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.autoregressive.diffusion_decoder.config.base.conditioner import ( + VideoLatentDiffusionDecoderConditionerConfig, +) +from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQJITTokenizer +from cosmos_predict1.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L + + +def get_cosmos_video_discrete_tokenizer_comp8x16x16( + resolution: str, + chunk_duration: int, + checkpoint_path: str, +): + assert resolution in ["720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 16 + + return L(DiscreteVideoFSQJITTokenizer)( + enc_fp=checkpoint_path.replace(".jit", "encoder.jit"), + dec_fp=checkpoint_path.replace(".jit", "decoder.jit"), + name="discrete_video_fsq", + latent_ch=6, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + latent_chunk_duration=1 + (pixel_chunk_duration - 1) // temporal_compression_factor, + max_enc_batch_size=8, + max_dec_batch_size=4, + levels=[8, 8, 8, 5, 5, 5], + compression_ratio=[temporal_compression_factor, spatial_compression_factor, spatial_compression_factor], + ) + + +def get_cosmos_video_tokenizer_comp8x8x8(resolution: str, chunk_duration: int, checkpoint_path=None): + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(JointImageVideoSharedJITTokenizer)( + video_vae=L(VideoJITTokenizer)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ), + image_vae=L(JITVAE)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_image=False, + is_bf16=True, + ), + name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", + latent_ch=16, + ) + + +def register_tokenizer(cs): + cs.store( + group="tokenizer", + package="model.tokenizer", + name="cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624", + node=get_cosmos_video_tokenizer_comp8x8x8( + resolution="720", + chunk_duration=121, + checkpoint_path="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/.jit", + ), + ) + + +def register_corruptor(cs): + cs.store( + group="tokenizer_corruptor", + package="model.tokenizer_corruptor", + name="cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224", + node=get_cosmos_video_discrete_tokenizer_comp8x16x16( + resolution="720", + chunk_duration=49, + checkpoint_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/.jit", + ), + ) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="video_latent_diffusion_decoder_cond", + node=VideoLatentDiffusionDecoderConditionerConfig, + ) + + +def register_configs(): + cs = ConfigStore.instance() + + register_conditioner(cs) + register_corruptor(cs) + register_tokenizer(cs) diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/inference.py b/cosmos_predict1/autoregressive/diffusion_decoder/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7dbdb256dabbbe383fd31eee9c92e1b936d8ce --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/inference.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc +from typing import List + +import torch + +from cosmos_predict1.autoregressive.configs.inference.inference_config import DiffusionDecoderSamplingConfig +from cosmos_predict1.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel +from cosmos_predict1.autoregressive.diffusion_decoder.utils import linear_blend_video_list, split_with_overlap +from cosmos_predict1.utils import log + + +def diffusion_decoder_process_tokens( + model: LatentDiffusionDecoderModel, + indices_tensor: List[torch.Tensor], + dd_sampling_config: DiffusionDecoderSamplingConfig = None, + original_video_example: torch.Tensor = None, + t5_emb_batch: List[torch.Tensor] = None, +): + _, T, H, W = original_video_example.shape + if dd_sampling_config is None: + dd_sampling_config = DiffusionDecoderSamplingConfig() + # indices_tensor is assumed to be a list of tensors with shape 1LHW + data_batch_list = [] + for sample_num, token_CTHW in enumerate(indices_tensor): + token_BCTHW = token_CTHW.unsqueeze(0).unsqueeze(1) + token_BCTHW = split_with_overlap( + token_BCTHW, + (dd_sampling_config.dd_train_num_video_frames - 1) // 8 + 1, + overlap=dd_sampling_config.overlap, + tobf16=False, + ) + data_batch_list.append( + { + "token_chunks": token_BCTHW, + "t5_text_embeddings": t5_emb_batch[sample_num].to(torch.bfloat16), + "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), + # other conditions + "image_size": torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda(), + "fps": torch.tensor([dd_sampling_config.fps] * 1, dtype=torch.bfloat16).cuda(), + "num_frames": torch.tensor( + [dd_sampling_config.dd_train_num_video_frames] * 1, dtype=torch.bfloat16 + ).cuda(), + "padding_mask": torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda(), + } + ) + + out_videos_batch = [] + + for idx, data_batch_template in enumerate(data_batch_list): + full_length_sample = [] + iterations = min(len(data_batch_template["token_chunks"]), dd_sampling_config.max_iter) + for iter in range(iterations): + gc.collect() + torch.cuda.empty_cache() + + data_batch = copy.deepcopy(data_batch_template) + data_batch["video"] = data_batch_template["token_chunks"][iter].cuda().to("cuda") + + log.debug(f"Run iter {iter} for video # {idx} at length {data_batch['video'].shape[2]}") + # org_video, + with torch.no_grad(): + samples_latent = model.generate_samples_from_batch( + data_batch, + guidance=dd_sampling_config.guidance, + state_shape=[ + dd_sampling_config.continuous_tokenizer_channel, + dd_sampling_config.continuous_tokenizer_spatial_compression_ratio, + H // 8, + W // 8, + ], + apply_corruptor=False, + preencode_condition=True, # We are using discrete model, so the input is already pre-encoded + num_steps=dd_sampling_config.num_steps, + ) + log.debug(f"Current sample shape {samples_latent.shape} for video # {idx} ") + full_length_sample.append(samples_latent.detach()) + + # Turn off because we remove CP + # 
distributed.barrier() + del data_batch + + torch.cuda.empty_cache() + + gc.collect() + torch.cuda.empty_cache() + + # Decode full-length samples and free GPU memory + full_length_sample_pixs = [model.decode(item).clamp(-1, 1).cpu() for item in full_length_sample] + torch.cuda.empty_cache() + + # Blend pixel samples + if len(full_length_sample_pixs) > 1: + full_length_sample_pixel_blend = linear_blend_video_list( + full_length_sample_pixs, dd_sampling_config.overlap + )[:, :, :T] + else: + full_length_sample_pixel_blend = full_length_sample_pixs[0][:, :, :T] + + # Batch size of full_length_sample_pixel_blend is always 1 + out_videos_batch.append((1 + full_length_sample_pixel_blend[0].cpu()) / 2) + return out_videos_batch diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/model.py b/cosmos_predict1/autoregressive/diffusion_decoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..93c474bfee40441c1bbed7feedcbe4c3199e6e4f --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/model.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +from diffusers import EDMEulerScheduler +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import BaseVideoCondition +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.module import parallel +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. + latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + +class LatentDiffusionDecoderModel(DiffusionT2WModel): + def __init__(self, config): + super().__init__(config) + """ + latent_corruptor: the corruption module is used to corrupt the latents. It add gaussian noise to the latents. + pixel_corruptor: the corruption module is used to corrupt the pixels. It apply gaussian blur kernel to pixels in a temporal consistent way. + tokenizer_corruptor: the corruption module is used to simulate tokenizer reconstruction errors. + + diffusion decoder noise augmentation pipeline for continuous token condition model: + condition: GT_video [T, H, W] + -> tokenizer_corruptor~(8x8x8) encode -> latent_corruptor -> tokenizer_corruptor~(8x8x8) decode + -> pixel corruptor + -> tokenizer~(1x8x8) encode -> condition [T, H/8, W/8] + GT: GT_video [T, H, W] -> tokenizer~(1x8x8) -> x_t [T, H/8, W/8]. 
+ + diffusion decoder noise augmentation pipeline for discrete token condition model: + condition: GT_video [T, H, W] + -> pixel corruptor + -> discrete tokenizer encode -> condition [T, T/8, H/16, W/16] + GT: GT_video [T, H, W] -> tokenizer~(8x8x8) -> x_t [T, T/8, H/8, W/8]. + + """ + self.latent_corruptor = lazy_instantiate(config.latent_corruptor) + self.pixel_corruptor = lazy_instantiate(config.pixel_corruptor) + self.tokenizer_corruptor = lazy_instantiate(config.tokenizer_corruptor) + + if self.latent_corruptor: + self.latent_corruptor.to(**self.tensor_kwargs) + if self.pixel_corruptor: + self.pixel_corruptor.to(**self.tensor_kwargs) + + if self.tokenizer_corruptor: + if hasattr(self.tokenizer_corruptor, "reset_dtype"): + self.tokenizer_corruptor.reset_dtype() + else: + assert self.pixel_corruptor is not None + + self.diffusion_decoder_cond_sigma_low = config.diffusion_decoder_cond_sigma_low + self.diffusion_decoder_cond_sigma_high = config.diffusion_decoder_cond_sigma_high + self.diffusion_decoder_corrupt_prob = config.diffusion_decoder_corrupt_prob + if hasattr(config, "condition_on_tokenizer_corruptor_token"): + self.condition_on_tokenizer_corruptor_token = config.condition_on_tokenizer_corruptor_token + else: + self.condition_on_tokenizer_corruptor_token = False + + self.scheduler = EDMEulerScheduler(sigma_max=80, sigma_min=0.02, sigma_data=self.sigma_data) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + apply_corruptor: bool = False, + corrupt_sigma: float = 0.01, + preencode_condition: bool = False, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. 
+ guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + preencode_condition (bool): use pre-computed condition if true, save tokenizer's inference time memory/ + """ + if not preencode_condition: + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + if n_sample is None: + n_sample = data_batch[self.input_data_key].shape[0] + + condition, uncondition = self._get_conditions( + data_batch, + is_negative_prompt=is_negative_prompt, + apply_corruptor=apply_corruptor, + corrupt_sigma=corrupt_sigma, + preencode_condition=preencode_condition, + ) + + self.scheduler.set_timesteps(num_steps) + + xt = torch.randn(size=(n_sample,) + tuple(state_shape)) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + + for t in self.scheduler.timesteps: + xt = xt.to(**self.tensor_kwargs) + xt_scaled = self.scheduler.scale_model_input(xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(net_output, t, xt).prev_sample + samples = xt + + if to_cp: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + apply_corruptor: bool = True, + corrupt_sigma: float = 1.5, + preencode_condition: bool = False, + ): + """Get the conditions for the model. 
+ + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + self._add_latent_conditions_to_data_batch( + data_batch, + apply_corruptor=apply_corruptor, + corrupt_sigma=corrupt_sigma, + preencode_condition=preencode_condition, + ) + + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + # For inference, check if parallel_state is initialized + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + condition.latent_condition = split_inputs_cp(condition.latent_condition, seq_dim=2, cp_group=cp_group) + condition.latent_condition_sigma = split_inputs_cp( + condition.latent_condition_sigma, seq_dim=2, cp_group=cp_group + ) + uncondition.latent_condition = split_inputs_cp(uncondition.latent_condition, seq_dim=2, cp_group=cp_group) + uncondition.latent_condition_sigma = split_inputs_cp( + uncondition.latent_condition_sigma, seq_dim=2, cp_group=cp_group + ) + return condition, uncondition + + def _add_latent_conditions_to_data_batch( + self, + data_batch: dict, + apply_corruptor: bool = True, + corrupt_sigma: float = 1.5, + preencode_condition: bool = False, + ): + # Latent state + raw_state = data_batch[self.input_data_key] + + if self.condition_on_tokenizer_corruptor_token: + if preencode_condition: + latent_condition = raw_state.to(torch.int32).contiguous() + corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition[:, 0]) + else: + corrupted_pixel = ( + self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state + ) + latent_condition = self.tokenizer_corruptor.encode(corrupted_pixel) + latent_condition = latent_condition[1] if isinstance(latent_condition, tuple) else latent_condition + corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition) + latent_condition = latent_condition.unsqueeze(1) + else: + if preencode_condition: + latent_condition = raw_state + corrupted_pixel = self.decode(latent_condition) + else: + corrupted_pixel = ( + self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state + ) + latent_condition = self.encode(corrupted_pixel).contiguous() + + sigma = ( + torch.rand((latent_condition.shape[0],)).to(**self.tensor_kwargs) * corrupt_sigma + ) # small value to indicate clean video + c_noise_cond = self.scheduler.precondition_noise(sigma=sigma) + if corrupt_sigma != self.diffusion_decoder_cond_sigma_low and self.diffusion_decoder_corrupt_prob > 0: + sigma_expand = sigma.view((-1,) + (1,) * (latent_condition.dim() - 1)) + noise = sigma_expand * torch.randn_like(latent_condition) + latent_condition = latent_condition + noise + data_batch["latent_condition_sigma"] = torch.ones_like(latent_condition[:, 0:1, ::]) * c_noise_cond + data_batch["latent_condition"] = latent_condition + + +def broadcast_condition(condition: 
BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: + condition_kwargs = {} + for k, v in condition.to_dict().items(): + if isinstance(v, torch.Tensor): + assert not v.requires_grad, f"{k} requires gradient. the current impl does not support it" + condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp) + condition = type(condition)(**condition_kwargs) + return condition diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/network.py b/cosmos_predict1/autoregressive/diffusion_decoder/network.py new file mode 100644 index 0000000000000000000000000000000000000000..7ca7b372b0d8f340fae6017d4e99b15c96a4d874 --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/network.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks +from torchvision import transforms + +from cosmos_predict1.diffusion.module.blocks import PatchEmbed +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class DiffusionDecoderGeneralDIT(GeneralDIT): + def __init__( + self, + *args, + is_diffusion_decoder: bool = True, + diffusion_decoder_condition_on_sigma: bool = False, + diffusion_decoder_condition_on_token: bool = False, + diffusion_decoder_token_condition_voc_size: int = 64000, + diffusion_decoder_token_condition_dim: int = 32, + **kwargs, + ): + # diffusion decoder setting + self.is_diffusion_decoder = is_diffusion_decoder + self.diffusion_decoder_condition_on_sigma = diffusion_decoder_condition_on_sigma + self.diffusion_decoder_condition_on_token = diffusion_decoder_condition_on_token + self.diffusion_decoder_token_condition_voc_size = diffusion_decoder_token_condition_voc_size + self.diffusion_decoder_token_condition_dim = diffusion_decoder_token_condition_dim + super().__init__(*args, **kwargs) + + def initialize_weights(self): + # Initialize transformer layers: + super().initialize_weights() + if self.diffusion_decoder_condition_on_token: + nn.init.constant_(self.token_embedder.weight, 0) + + @property + def is_context_parallel_enabled(self): + return self.cp_group is not None + + def enable_context_parallel(self, cp_group: ProcessGroup): + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + # self.pos_embedder.enable_context_parallel(cp_group) + self.pos_embedder.cp_group = cp_group + + if self.extra_per_block_abs_pos_emb: + # self.extra_pos_embedder.enable_context_parallel(cp_group) + self.extra_pos_embedder.cp_group = cp_group + + # Loop through the model to set up context parallel. 
+ for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff", "cross_attn", "ca"]: + continue + elif layer.block.attn.backend == "transformer_engine": + layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + + log.debug(f"[CP] Enable context parallelism with size {cp_size}") + + def disable_context_parallel(self): + self.cp_group = None + self.cp_size = None + + self.pos_embedder.disable_context_parallel() + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.disable_context_parallel() + + # Loop through the model to disable context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.cp_group = None + layer.block.attn.attn_op.cp_ranks = None + + log.debug("[CP] Disable context parallelism.") + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + is_diffusion_decoder, + diffusion_decoder_token_condition_dim, + diffusion_decoder_condition_on_sigma, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + self.is_diffusion_decoder, + self.diffusion_decoder_token_condition_dim, + self.diffusion_decoder_condition_on_sigma, + ) + in_channels = ( + in_channels + in_channels + if (is_diffusion_decoder and not self.diffusion_decoder_condition_on_token) + else in_channels + ) + in_channels = in_channels + 1 if diffusion_decoder_condition_on_sigma else in_channels + in_channels = ( + in_channels + self.diffusion_decoder_token_condition_dim + if self.diffusion_decoder_condition_on_token + else in_channels + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + if self.diffusion_decoder_condition_on_token: + self.token_embedder = nn.Embedding( + self.diffusion_decoder_token_condition_voc_size, self.diffusion_decoder_token_condition_dim + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. 
+ - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.diffusion_decoder_condition_on_token: + latent_condition = self.token_embedder(latent_condition) + B, _, T, H, W, _ = latent_condition.shape + latent_condition = rearrange(latent_condition, "B 1 T H W D -> (B T) (1 D) H W") + + latent_condition = transforms.functional.resize( + latent_condition, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.BILINEAR + ) + latent_condition = rearrange(latent_condition, "(B T) D H W -> B D T H W ", B=B, T=T) + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition], dim=1) + if self.diffusion_decoder_condition_on_sigma: + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition_sigma], dim=1) + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb diff --git a/cosmos_predict1/autoregressive/diffusion_decoder/utils.py b/cosmos_predict1/autoregressive/diffusion_decoder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c584c7c9a5e03bcb3b808d053f89e7c2aeaf9cf --- /dev/null +++ b/cosmos_predict1/autoregressive/diffusion_decoder/utils.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + + +def split_with_overlap(video_BCTHW, num_video_frames, overlap=2, tobf16=True): + """ + Splits the video tensor into chunks of num_video_frames with a specified overlap. + + Args: + - video_BCTHW (torch.Tensor): Input tensor with shape [Batch, Channels, Time, Height, Width]. + - num_video_frames (int): Number of frames per chunk. + - overlap (int): Number of overlapping frames between chunks. + + Returns: + - List of torch.Tensors: List of video chunks with overlap. 
+ """ + # Get the dimensions of the input tensor + B, C, T, H, W = video_BCTHW.shape + + # Ensure overlap is less than num_video_frames + assert overlap < num_video_frames, "Overlap should be less than num_video_frames." + + # List to store the chunks + chunks = [] + + # Step size for the sliding window + step = num_video_frames - overlap + + # Loop through the time dimension (T) with the sliding window + for start in range(0, T - overlap, step): + end = start + num_video_frames + # Handle the case when the last chunk might go out of bounds + if end > T: + # Get the last available frame + num_padding_frames = end - T + chunk = F.pad(video_BCTHW[:, :, start:T, :, :], (0, 0, 0, 0, 0, num_padding_frames), mode="reflect") + else: + # Regular case: no padding needed + chunk = video_BCTHW[:, :, start:end, :, :] + if tobf16: + chunks.append(chunk.to(torch.bfloat16)) + else: + chunks.append(chunk) + return chunks + + +def linear_blend_video_list(videos, D): + """ + Linearly blends a list of videos along the time dimension with overlap length D. + + Parameters: + - videos: list of video tensors, each of shape [b, c, t, h, w] + - D: int, overlap length + + Returns: + - output_video: blended video tensor of shape [b, c, L, h, w] + """ + assert len(videos) >= 2, "At least two videos are required." + b, c, t, h, w = videos[0].shape + N = len(videos) + + # Ensure all videos have the same shape + for video in videos: + assert video.shape == (b, c, t, h, w), "All videos must have the same shape." + + # Calculate total output length + L = N * t - D * (N - 1) + output_video = torch.zeros((b, c, L, h, w), device=videos[0].device) + + output_index = 0 # Current index in the output video + + for i in range(N): + if i == 0: + # Copy frames from the first video up to t - D + output_video[:, :, output_index : output_index + t - D, :, :] = videos[i][:, :, : t - D, :, :] + output_index += t - D + else: + # Blend overlapping frames between videos[i-1] and videos[i] + blend_weights = torch.linspace(0, 1, steps=D, device=videos[0].device) + + for j in range(D): + w1 = 1 - blend_weights[j] + w2 = blend_weights[j] + frame_from_prev = videos[i - 1][:, :, t - D + j, :, :] + frame_from_curr = videos[i][:, :, j, :, :] + output_frame = w1 * frame_from_prev + w2 * frame_from_curr + output_video[:, :, output_index, :, :] = output_frame + output_index += 1 + + if i < N - 1: + # Copy non-overlapping frames from current video up to t - D + frames_to_copy = t - 2 * D + if frames_to_copy > 0: + output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][ + :, :, D : t - D, :, : + ] + output_index += frames_to_copy + else: + # For the last video, copy frames from D to t + frames_to_copy = t - D + output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][:, :, D:, :, :] + output_index += frames_to_copy + + return output_video diff --git a/cosmos_predict1/autoregressive/inference/__init__.py b/cosmos_predict1/autoregressive/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/inference/base.py b/cosmos_predict1/autoregressive/inference/base.py new file mode 100644 index 0000000000000000000000000000000000000000..214836c6ada2eaf8dd9034177b15400fb5eb893a --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/base.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import imageio +import torch + +from cosmos_predict1.autoregressive.inference.world_generation_pipeline import ARBaseGenerationPipeline +from cosmos_predict1.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args +from cosmos_predict1.utils import log + + +def parse_args(): + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + parser.add_argument( + "--ar_model_dir", + type=str, + default="Cosmos-Predict1-4B", + ) + parser.add_argument("--input_type", type=str, default="video", help="Type of input", choices=["image", "video"]) + args = parser.parse_args() + return args + + +def main(args): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple images/videos from input + - Generating videos from images/videos + - Saving the generated videos to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (temperature, top_p) + - Input/output settings (images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
+ """ + inference_type = "base" # When the inference_type is "base", AR model does not take text as input, the world generation is purely based on the input video + sampling_config = validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + + # Initialize base generation model pipeline + pipeline = ARBaseGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.ar_model_dir, + disable_diffusion_decoder=args.disable_diffusion_decoder, + offload_guardrail_models=args.offload_guardrail_models, + offload_diffusion_decoder=args.offload_diffusion_decoder, + offload_network=args.offload_ar_model, + offload_tokenizer=args.offload_tokenizer, + disable_guardrail=args.disable_guardrail, + parallel_size=args.num_gpus, + ) + + # Load input image(s) or video(s) + input_videos = load_vision_input( + input_type=args.input_type, + batch_input_path=args.batch_input_path, + input_image_or_video_path=args.input_image_or_video_path, + data_resolution=args.data_resolution, + num_input_frames=args.num_input_frames, + ) + + for idx, input_filename in enumerate(input_videos): + inp_vid = input_videos[input_filename] + # Generate video + log.info(f"Run with image or video path: {input_filename}") + out_vid = pipeline.generate( + inp_vid=inp_vid, + num_input_frames=args.num_input_frames, + seed=args.seed, + sampling_config=sampling_config, + ) + if out_vid is None: + log.critical("Guardrail blocked base generation.") + continue + + # Save video + if args.input_image_or_video_path: + out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + else: + out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4") + + imageio.mimsave(out_vid_path, out_vid, fps=25) + log.info(f"Saved video to {out_vid_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + torch._C._jit_set_texpr_fuser_enabled(False) + args = parse_args() + main(args) diff --git a/cosmos_predict1/autoregressive/inference/video2world.py b/cosmos_predict1/autoregressive/inference/video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..532919570a433ac21e76333db87822487c7e40b3 --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/video2world.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
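For reference, a typical single-GPU run of this script might look like the command below. Only --ar_model_dir and --input_type are defined in this file; the remaining flags are assumed to be registered by add_common_arguments (not shown in this diff), matching the attribute names used in main() above, and the values are placeholders.

PYTHONPATH=. python cosmos_predict1/autoregressive/inference/base.py \
    --checkpoint_dir checkpoints \
    --ar_model_dir Cosmos-Predict1-4B \
    --input_type video \
    --input_image_or_video_path <path/to/input.mp4> \
    --num_input_frames 9 \
    --video_save_name base_example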
+ +import argparse +import os + +import imageio +import torch + +from cosmos_predict1.autoregressive.inference.world_generation_pipeline import ARVideo2WorldGenerationPipeline +from cosmos_predict1.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args +from cosmos_predict1.utils import log +from cosmos_predict1.utils.io import read_prompts_from_file + + +def parse_args(): + parser = argparse.ArgumentParser(description="Prompted video to world generation demo script") + add_common_arguments(parser) + parser.add_argument( + "--ar_model_dir", + type=str, + default="Cosmos-Predict1-5B-Video2World", + ) + parser.add_argument( + "--input_type", + type=str, + default="text_and_video", + choices=["text_and_image", "text_and_video"], + help="Input types", + ) + parser.add_argument( + "--prompt", + type=str, + help="Text prompt for generating a single video", + ) + parser.add_argument( + "--offload_text_encoder_model", + action="store_true", + help="Offload T5 model after inference", + ) + args = parser.parse_args() + return args + + +def main(args): + """Run prompted video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (temperature, top_p) + - Input/output settings (images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
+ """ + inference_type = "video2world" # When the inference_type is "video2world", AR model takes both text and video as input, the world generation is based on the input text prompt and video + sampling_config = validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + + # Initialize prompted base generation model pipeline + pipeline = ARVideo2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.ar_model_dir, + disable_diffusion_decoder=args.disable_diffusion_decoder, + offload_guardrail_models=args.offload_guardrail_models, + offload_diffusion_decoder=args.offload_diffusion_decoder, + offload_network=args.offload_ar_model, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + disable_guardrail=args.disable_guardrail, + parallel_size=args.num_gpus, + ) + + # Load input image(s) or video(s) + input_videos = load_vision_input( + input_type=args.input_type, + batch_input_path=args.batch_input_path, + input_image_or_video_path=args.input_image_or_video_path, + data_resolution=args.data_resolution, + num_input_frames=args.num_input_frames, + ) + # Load input prompt(s) + if args.batch_input_path: + prompts_list = read_prompts_from_file(args.batch_input_path) + else: + prompts_list = [{"visual_input": args.input_image_or_video_path, "prompt": args.prompt}] + + # Iterate through prompts + for idx, prompt_entry in enumerate(prompts_list): + video_path = prompt_entry["visual_input"] + input_filename = os.path.basename(video_path) + + # Check if video exists in loaded videos + if input_filename not in input_videos: + log.critical(f"Input file {input_filename} not found, skipping prompt.") + continue + + inp_vid = input_videos[input_filename] + inp_prompt = prompt_entry["prompt"] + + # Generate video + log.info(f"Run with input: {prompt_entry}") + out_vid = pipeline.generate( + inp_prompt=inp_prompt, + inp_vid=inp_vid, + num_input_frames=args.num_input_frames, + seed=args.seed, + sampling_config=sampling_config, + ) + if out_vid is None: + log.critical("Guardrail blocked video2world generation.") + continue + + # Save video + if args.input_image_or_video_path: + out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + else: + out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4") + imageio.mimsave(out_vid_path, out_vid, fps=25) + + log.info(f"Saved video to {out_vid_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + torch._C._jit_set_texpr_fuser_enabled(False) + args = parse_args() + main(args) diff --git a/cosmos_predict1/autoregressive/inference/world_generation_pipeline.py b/cosmos_predict1/autoregressive/inference/world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fd5b392395d777c120528241dbf89aef050efa --- /dev/null +++ b/cosmos_predict1/autoregressive/inference/world_generation_pipeline.py @@ -0,0 +1,1031 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from einops import rearrange +from megatron.core import ModelParallelConfig, parallel_state +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model_config +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.autoregressive.configs.inference.inference_config import ( + DataShapeConfig, + DiffusionDecoderSamplingConfig, + InferenceConfig, + SamplingConfig, +) +from cosmos_predict1.autoregressive.diffusion_decoder.inference import diffusion_decoder_process_tokens +from cosmos_predict1.autoregressive.diffusion_decoder.model import LatentDiffusionDecoderModel +from cosmos_predict1.autoregressive.model import AutoRegressiveModel, update_model_config +from cosmos_predict1.autoregressive.utils.inference import _SUPPORTED_CONTEXT_LEN, prepare_video_batch_for_saving +from cosmos_predict1.autoregressive.utils.parallel import broadcast_data_batch_in_tp_cp_group, get_batch_on_this_cp_rank +from cosmos_predict1.diffusion.inference.inference_utils import ( + load_model_by_config, + load_network_model, + load_tokenizer_model, +) +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.base_world_generation_pipeline import BaseWorldGenerationPipeline + + +def detect_model_size_from_ckpt_path(ckpt_path: str) -> str: + """Detect model size from checkpoint path. + + Args: + ckpt_path: Path to model checkpoint file + + Returns: + str: Model size ('4b', '5b', '12b', or '13b') + + Examples: + >>> detect_model_size_from_ckpt_path("model_4B.pt") + '4b' + """ + model_size = "4b" + if "4B" in ckpt_path: + model_size = "4b" + elif "5B" in ckpt_path: + model_size = "5b" + elif "12B" in ckpt_path: + model_size = "12b" + elif "13B" in ckpt_path: + model_size = "13b" + else: + log.warning(f"Could not detect model size from checkpoint path: {ckpt_path}") + return model_size + + +def create_inference_config( + model_ckpt_path: str, + tokenizer_ckpt_path: str, + model_size: str = "4b", + parallel_size: int = 4, + batch_size: int = 1, + inference_type: str = "base", +) -> InferenceConfig: + """Create inference configuration for model. 
+ + Args: + model_ckpt_path: Path to model checkpoint + tokenizer_ckpt_path: Path to tokenizer checkpoint + model_size: Size of model ('4b', '5b', '12b', '13b') + parallel_size: Number of GPUs for parallelism + batch_size: Batch size for inference + inference_type: Type of inference ('base' or 'video2world') + + Returns: + InferenceConfig: Configuration object for inference + """ + model_size = model_size.lower() + # For inference config + kwargs = {} + if inference_type == "video2world": + kwargs.update( + dict( + insert_cross_attn=True, + insert_cross_attn_every_k_layers=1, + context_dim=1024, + training_type="text_to_video", + apply_abs_pos_emb=True, + ) + ) + if model_size == "5b": + model_size = "4b" # The base model (excluding the cross attention layers) is the 4B model + elif model_size == "13b": + model_size = "12b" # The base model (excluding the cross attention layers) is the 12B model + else: + raise ValueError(f"Unsupported model size for video2world inference_type: {model_size}") + else: + assert inference_type == "base", f"Unsupported inference_type: {inference_type}" + + model_config, tokenizer_config = create_video2world_model_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_size=model_size, + tensor_model_parallel_size=parallel_size, + rope_dim="3D", + add_special_tokens=False, + pixel_chunk_duration=33, + num_video_frames=33, + num_condition_latents_t=1, + batch_size=batch_size, + video_height=640, + video_width=1024, + **kwargs, + ) + + inference_config = InferenceConfig() + + inference_config.model_config = model_config + inference_config.tokenizer_config = tokenizer_config + + inference_config.data_shape_config = DataShapeConfig( + num_video_frames=model_config.num_video_frames, + height=model_config.video_height, + width=model_config.video_width, + latent_shape=model_config.video_latent_shape, + ) + inference_config.model_config.fuse_qkv = False + return inference_config + + +class ARBaseGenerationPipeline(BaseWorldGenerationPipeline): + """Base class for autoregressive world generation models. + + Handles the core functionality for generating videos using autoregressive models. + Provides configurable GPU memory management through model offloading and supports + different inference types for video generation. + + Attributes: + inference_config (InferenceConfig): Configuration for model inference + tokenizer_config (TokenizerConfig): Configuration for tokenizer + disable_diffusion_decoder (bool): Whether diffusion decoder is disabled + parallel_size (int): Number of GPUs for parallelism + latent_shape (List[int]): Shape of video latents [T, H, W] + _supported_context_len (int): Supported context window length + latent_chunk_duration (int): Duration of latent chunks + pixel_chunk_duration (int): Duration of pixel chunks + diffusion_decoder_model (Optional[nn.Module]): The diffusion decoder model + """ + + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + has_text_input: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + disable_diffusion_decoder: bool = False, + offload_guardrail_models: bool = False, + offload_diffusion_decoder: bool = False, + disable_guardrail: bool = False, + parallel_size: int = 1, + ): + """Initialize the autoregressive world generation pipeline. 
+ + Args: + inference_type: Type of world generation ('base' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the AR checkpoint to load + has_text_input: Whether the pipeline takes text input for world generation + disable_diffusion_decoder: Whether to disable the diffusion decoder stage + offload_network: Whether to offload AR model from GPU after use + offload_guardrail_models: Whether to offload content filtering models + offload_diffusion_decoder: Whether to offload diffusion decoder + disable_guardrail: Whether to disable guardrail + parallel_size: Number of GPUs for parallelism + + Raises: + AssertionError: If inference_type is not 'base' or 'video2world' + """ + assert inference_type in [ + "base", + "video2world", + ], "Invalid inference_type, must be 'base' or 'video2world'" + + # Create inference config + model_size = detect_model_size_from_ckpt_path(checkpoint_name) + model_ckpt_path = os.path.join(checkpoint_dir, checkpoint_name, "model.pt") + tokenizer_ckpt_path = os.path.join(checkpoint_dir, "Cosmos-Tokenize1-DV8x16x16-720p/ema.jit") + + inference_config: InferenceConfig = create_inference_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_size=model_size, + parallel_size=parallel_size, + inference_type=inference_type, + ) + + self.inference_config = inference_config + self.parallel_size = parallel_size + self.disable_diffusion_decoder = disable_diffusion_decoder + + if not disable_diffusion_decoder: + self.diffusion_decoder_ckpt_path = os.path.join( + checkpoint_dir, "Cosmos-Predict1-7B-Decoder-DV8x16x16ToCV8x8x8-720p/model.pt" + ) + self.diffusion_decoder_config = "DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token" + self.diffusion_decoder_tokenizer_path = os.path.join(checkpoint_dir, "Cosmos-Tokenize1-CV8x8x8-720p") + self.dd_sampling_config = DiffusionDecoderSamplingConfig() + aux_vars_path = os.path.join(os.path.dirname(self.diffusion_decoder_ckpt_path), "aux_vars.pt") + # We use a generic prompt when no text prompts are available for diffusion decoder. + # Generic prompt used - "high quality, 4k, high definition, smooth video" + aux_vars = torch.load(aux_vars_path, weights_only=True) + self.generic_prompt = dict() + self.generic_prompt["context"] = aux_vars["context"].cuda() + self.generic_prompt["context_mask"] = aux_vars["context_mask"].cuda() + + self.latent_shape = inference_config.data_shape_config.latent_shape # [L, 40, 64] + self._supported_context_len = _SUPPORTED_CONTEXT_LEN + self.tokenizer_config = inference_config.tokenizer_config + + self.offload_diffusion_decoder = offload_diffusion_decoder + self.diffusion_decoder_model = None + if not self.offload_diffusion_decoder and not disable_diffusion_decoder: + self._load_diffusion_decoder() + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + has_text_input=has_text_input, + offload_guardrail_models=offload_guardrail_models, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + disable_guardrail=disable_guardrail, + offload_text_encoder_model=True, + ) + + def _load_model(self): + """Load and initialize the autoregressive model. + + Sets up parallelism if enabled (parallel_size > 1). + Initializes model parallel state and seeds for reproducibility. + Creates and configures the autoregressive model with appropriate settings. 
+ """ + if self.parallel_size > 1: + model_parallel = ModelParallelConfig( + tensor_model_parallel_size=self.parallel_size, + context_parallel_size=1, + bf16=True, + params_dtype=getattr(torch, "bfloat16"), + ) + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + model_parallel_cuda_manual_seed(0) + parallel_state.destroy_model_parallel() + else: + model_parallel = None + self.model_config = self.inference_config.model_config + self.model_config = update_model_config( + self.model_config, + inference_tensor_parallel_size=self.parallel_size, + ) + self.model = AutoRegressiveModel( + config=self.inference_config.model_config, + model_parallel=model_parallel, + ) + + def _load_network(self): + """Load network weights for the autoregressive model. + + Sets up distributed training if available and handles checkpoint loading. + Supports tensor parallel model sharding when enabled. Coordinates across + distributed process groups if needed. + """ + if dist.is_available() and dist.is_initialized(): + # ddp_group = parallel_state.get_data_parallel_group() + # tp_group = parallel_state.get_tensor_model_parallel_group() + # dist.barrier(group=ddp_group) + # dist.barrier(group=tp_group) + pass + if "{rank}" in self.model_config.ckpt_path: + shard_checkpoint = False + else: + shard_checkpoint = ( + dist.is_available() and dist.is_initialized() + ) # Take the TP-rank specific checkpoint when initializing the model + + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + self.model.load_ar_model( + shard_checkpoint=shard_checkpoint, tokenizer_config=self.inference_config.tokenizer_config + ) + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + def _load_tokenizer(self): + """Load and initialize the tokenizer model. + + Configures the tokenizer using settings from inference_config and + attaches it to the autoregressive model. + """ + self.model.load_tokenizer(tokenizer_config=self.inference_config.tokenizer_config) + + def _load_diffusion_decoder(self): + """Load and initialize the diffusion decoder model. + + Sets up context parallelism if enabled. Loads model weights, + and configures parallel processing groups as needed. + Handles model parallel state initialization and management. + """ + self.diffusion_decoder_model = load_model_by_config( + config_job_name=self.diffusion_decoder_config, + config_file="cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py", + model_class=LatentDiffusionDecoderModel, + ) + load_network_model(self.diffusion_decoder_model, self.diffusion_decoder_ckpt_path) + load_tokenizer_model(self.diffusion_decoder_model, self.diffusion_decoder_tokenizer_path) + + def _offload_diffusion_decoder(self): + """Offload diffusion decoder model from GPU memory.""" + if self.diffusion_decoder_model is not None: + del self.diffusion_decoder_model + self.diffusion_decoder_model = None + gc.collect() + torch.cuda.empty_cache() + + def _run_model_with_offload( + self, inp_vid: torch.Tensor, num_input_frames: int, seed: int, sampling_config: SamplingConfig + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Run the autoregressive model to generate video tokens. + + Takes input video frames and generates new video tokens using the autoregressive model. + Handles context frame selection and token generation. 
+ + Args: + inp_vid (torch.Tensor): Input video tensor of shape + num_input_frames (int): Number of context frames to use from input. The tensor shape should be (B x T x 3 x H x W). + seed (int): Random seed for generation + sampling_config (SamplingConfig): Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors, + List of token index tensors, + List of prompt embedding tensors + ) + """ + # Choosing the context length from list of available contexts + out_videos_cur_batch, indices_tensor_cur_batch = self._run_model( + inp_vid, num_input_frames, seed, sampling_config + ) + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + return out_videos_cur_batch, indices_tensor_cur_batch + + def _run_model( + self, inp_vid: torch.Tensor, num_input_frames: int, seed: int, sampling_config: SamplingConfig + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Run the autoregressive model to generate video tokens. + + Takes input video frames and generates new video tokens using the autoregressive model. + Handles context frame selection and token generation. + + Args: + inp_vid (torch.Tensor): Input video tensor of shape + num_input_frames (int): Number of context frames to use from input. The tensor shape should be (B x T x 3 x H x W). + seed (int): Random seed for generation + sampling_config (SamplingConfig): Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors, + List of token index tensors, + List of prompt embedding tensors + ) + """ + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + + # Choosing the context length from list of available contexts + latent_context_t_size = 0 + context_used = 0 + for _clen in self._supported_context_len: + if num_input_frames >= _clen: + context_used = _clen + latent_context_t_size += 1 + log.info(f"Using input size of {context_used} frames") + + data_batch = {"video": inp_vid} + data_batch = misc.to(data_batch, "cuda") + + T, H, W = self.latent_shape + num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W])) + + out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch( + data_batch=data_batch, + num_tokens_to_generate=num_gen_tokens, + sampling_config=sampling_config, + tokenizer_config=self.tokenizer_config, + latent_shape=self.latent_shape, + task_condition="video", + num_chunks_to_generate=1, + seed=seed, + ) + + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + return out_videos_cur_batch, indices_tensor_cur_batch + + def _run_diffusion_decoder( + self, + out_videos_cur_batch: List[torch.Tensor], + indices_tensor_cur_batch: List[torch.Tensor], + t5_emb_batch: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Process generated tokens through the diffusion decoder. + + Enhances video quality through diffusion-based decoding. 
+ + Args: + out_videos_cur_batch: List of generated video tensors + indices_tensor_cur_batch: List of token indices tensors + t5_emb_batch: List of text embeddings for conditioning + + Returns: + list: Enhanced video tensors after diffusion processing + """ + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(context_parallel_size=self.parallel_size) + process_group = parallel_state.get_context_parallel_group() + self.diffusion_decoder_model.net.enable_context_parallel(process_group) + + out_videos_cur_batch_dd = diffusion_decoder_process_tokens( + model=self.diffusion_decoder_model, + indices_tensor=indices_tensor_cur_batch, + dd_sampling_config=self.dd_sampling_config, + original_video_example=out_videos_cur_batch[0], + t5_emb_batch=t5_emb_batch, + ) + + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + return out_videos_cur_batch_dd + + def _run_diffusion_decoder_with_offload( + self, + out_videos_cur_batch: List[torch.Tensor], + indices_tensor_cur_batch: List[torch.Tensor], + t5_emb_batch: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Run diffusion decoder with memory management. + + Loads decoder if needed, processes videos, and offloads decoder afterward + if configured in offload_diffusion_decoder. + + Args: + out_videos_cur_batch: List of generated video tensors + indices_tensor_cur_batch: List of token indices tensors + t5_emb_batch: List of text embeddings for conditioning + + Returns: + list: Enhanced video tensors after diffusion processing + """ + if self.offload_diffusion_decoder: + self._load_diffusion_decoder() + out_videos_cur_batch = self._run_diffusion_decoder(out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch) + if self.offload_diffusion_decoder: + self._offload_diffusion_decoder() + return out_videos_cur_batch + + def generate( + self, + inp_vid: torch.Tensor, + sampling_config: SamplingConfig, + num_input_frames: int = 9, + seed: int = 0, + ) -> np.ndarray | None: + """Generate a video continuation from input frames. + + Pipeline steps: + 1. Generates video tokens using autoregressive model + 2. Optionally enhances quality via diffusion decoder + 3. 
Applies safety checks if enabled + + Args: + inp_vid: Input video tensor of shape (batch_size, time, channels=3, height, width) + sampling_config: Parameters controlling the generation process + num_input_frames: Number of input frames to use as context (default: 9) + seed: Random seed for reproducibility (default: 0) + + Returns: + np.ndarray | None: Generated video as numpy array (time, height, width, channels) + if generation successful, None if safety checks fail + """ + log.info("Run generation") + out_videos_cur_batch, indices_tensor_cur_batch = self._run_model_with_offload( + inp_vid, num_input_frames, seed, sampling_config + ) + log.info("Finish AR model generation") + + if not self.disable_diffusion_decoder: + log.info("Run diffusion decoder on generated tokens") + out_videos_cur_batch = self._run_diffusion_decoder_with_offload( + out_videos_cur_batch, indices_tensor_cur_batch, t5_emb_batch=[self.generic_prompt["context"]] + ) + log.info("Finish diffusion decoder on generated tokens") + out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch) + output_video = out_videos_cur_batch[0] + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + output_video = self._run_guardrail_on_video_with_offload(output_video) + if output_video is None: + log.critical("Generated video is not safe") + return None + log.info("Finish guardrail on generated video") + + return output_video + + @torch.inference_mode() + def generate_partial_tokens_from_data_batch( + self, + data_batch: dict, + num_tokens_to_generate: int, + sampling_config: SamplingConfig, + tokenizer_config: TokenizerConfig, + latent_shape: list[int], + task_condition: str, + num_chunks_to_generate: int = 1, + seed: int = 0, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """Generate video tokens from partial input tokens with conditioning. + + Handles token generation and decoding process: + 1. Processes input batch and applies conditioning + 2. Generates specified number of new tokens + 3. 
Decodes tokens to video frames + + Args: + data_batch: Dictionary containing input data including video and optional context + num_tokens_to_generate: Number of tokens to generate + sampling_config: Configuration for sampling parameters + tokenizer_config: Configuration for tokenizer, including video tokenizer settings + latent_shape: Shape of video latents [T, H, W] + task_condition: Type of generation task ('video' or 'text_and_video') + num_chunks_to_generate: Number of chunks to generate (default: 1) + seed: Random seed for generation (default: 0) + + Returns: + tuple containing: + - List[torch.Tensor]: Generated videos + - List[torch.Tensor]: Input videos + - List[torch.Tensor]: Generated tokens + - List[torch.Tensor]: Token index tensors + """ + log.debug(f"Starting generate_partial_tokens_from_data_batch with seed {seed}") + log.debug(f"Number of tokens to generate: {num_tokens_to_generate}") + log.debug(f"Latent shape: {latent_shape}") + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + broadcast_data_batch_in_tp_cp_group(data_batch) + + video_token_start = tokenizer_config.video_tokenizer.tokenizer_offset + video_vocab_size = tokenizer_config.video_tokenizer.vocab_size + video_token_end = video_token_start + video_vocab_size + + logit_clipping_range = [video_token_start, video_token_end] + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._load_tokenizer() + + assert logit_clipping_range == [ + 0, + self.model.tokenizer.video_vocab_size, + ], f"logit_clipping_range {logit_clipping_range} is not supported for fast generate. Expected [0, {self.model.tokenizer.video_vocab_size}]" + + out_videos = {} + out_indices_tensors = {} + + # for text2world, we only add a token at the beginning of the video tokens, this applies to 5B and 13B models + if self.model.tokenizer.tokenizer_config.training_type == "text_to_video": + num_bov_tokens = 1 + num_eov_tokens = 0 + else: + num_eov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0 + num_bov_tokens = 1 if self.model.tokenizer.tokenizer_config.add_special_tokens else 0 + + chunk_idx = 0 + out_videos[chunk_idx] = [] + out_indices_tensors[chunk_idx] = [] + + # get the context embedding and mask + context = data_batch.get("context", None) if task_condition != "video" else None + context_mask = data_batch.get("context_mask", None) if task_condition != "video" else None + if context is not None: + context = misc.to(context, "cuda").detach().clone() + if context_mask is not None: + context_mask = misc.to(context_mask, "cuda").detach().clone() + + # get the video tokens + data_tokens, token_boundaries = self.model.tokenizer.tokenize(data_batch=data_batch) + data_tokens = misc.to(data_tokens, "cuda").detach().clone() + if parallel_state.get_context_parallel_world_size() > 1: + data_tokens = get_batch_on_this_cp_rank(data_tokens) + batch_size = data_tokens.shape[0] + + for sample_num in range(batch_size): + input_tokens = data_tokens[sample_num][0 : token_boundaries["video"][sample_num][1]] # [B, L] + input_tokens = [ + input_tokens[0 : -num_tokens_to_generate - num_eov_tokens].tolist() + ] # -1 is to exclude eov token + log.debug( + f"Run sampling. 
# input condition tokens: {len(input_tokens[0])}; # generate tokens: {num_tokens_to_generate + num_eov_tokens}; " + f"full length of the data tokens: {len(data_tokens[sample_num])}: {data_tokens[sample_num]}" + ) + video_start_boundary = token_boundaries["video"][sample_num][0] + num_bov_tokens + + video_decoded, indices_tensor = self.generate_video_from_tokens( + prompt_tokens=input_tokens, + latent_shape=latent_shape, + video_start_boundary=video_start_boundary, + max_gen_len=num_tokens_to_generate, + sampling_config=sampling_config, + logit_clipping_range=logit_clipping_range, + seed=seed, + context=context, + context_mask=context_mask, + ) # BCLHW, range [0, 1] + + # For the first chunk, we store the entire generated video + out_videos[chunk_idx].append(video_decoded[sample_num].detach().clone()) + out_indices_tensors[chunk_idx].append(indices_tensor[sample_num].detach().clone()) + + output_videos = [] + output_indice_tensors = [] + for sample_num in range(len(out_videos[0])): + tensors_to_concat = [out_videos[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate)] + concatenated = torch.cat(tensors_to_concat, dim=1) + output_videos.append(concatenated) + + indices_tensor_to_concat = [ + out_indices_tensors[chunk_idx][sample_num] for chunk_idx in range(num_chunks_to_generate) + ] + concatenated_indices_tensor = torch.cat(indices_tensor_to_concat, dim=1) # BLHW + output_indice_tensors.append(concatenated_indices_tensor) + + return output_videos, output_indice_tensors + + def generate_video_from_tokens( + self, + prompt_tokens: list[torch.Tensor], + latent_shape: list[int], + video_start_boundary: int, + max_gen_len: int, + sampling_config: SamplingConfig, + logit_clipping_range: list[int], + seed: int = 0, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Function to generate video from input tokens. These input tokens can be initial text tokens (in case of text to video), + or partial ground truth tokens. + + Handles the core token-to-video generation process: + 1. Generates new tokens using the autoregressive model + 2. Handles padding and token sequence completion + 3. Reshapes and processes generated tokens + 4. Decodes final tokens into video frames + + Args: + model (AutoRegressiveModel): LLama model instance + prompt_tokens (list): Prompt tokens used by the model + latent_shape (list): Shape of the video latents + video_start_boundary (int): Index where the video tokens start + max_gen_len (int): Maximum length of the tokens that needs to be generated + sampling_config (SamplingConfig): Config used by sampler during inference + logit_clipping_range (list): Range of indices in the logits to be clipped, e.g. [video_token_start, video_token_end] + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. 
+ Returns: + tuple containing: + - List[torch.Tensor]: Generated videos + - List[torch.Tensor]: Generated tokens + - List[torch.Tensor]: Token index tensors + """ + # Combine the tokens and do padding, sometimes the generated tokens end before the max_gen_len + total_seq_len = np.prod(latent_shape) + + assert not sampling_config.logprobs + + stop_tokens = self.model.tokenizer.stop_tokens + if self.offload_tokenizer: + self._offload_tokenizer() + if self.offload_network: + self._load_network() + + generation_tokens, _ = self.model.generate( + prompt_tokens=prompt_tokens, + temperature=sampling_config.temperature, + top_p=sampling_config.top_p, + echo=sampling_config.echo, + seed=seed, + context=context, + context_mask=context_mask, + max_gen_len=max_gen_len, + compile_sampling=sampling_config.compile_sampling, + compile_prefill=sampling_config.compile_prefill, + stop_tokens=stop_tokens, + verbose=True, + ) + generation_tokens = generation_tokens[:, video_start_boundary:] + # Combine the tokens and do padding, sometimes the generated tokens end before the max_gen_len + if generation_tokens.shape[1] < total_seq_len: + log.warning( + f"Generated video tokens (shape:{generation_tokens.shape}) shorted than expected {total_seq_len}. Could be the model produce end token early. Repeat the last token to fill the sequence in order for decoding." + ) + padding_len = total_seq_len - generation_tokens.shape[1] + padding_tokens = generation_tokens[:, [-1]].repeat(1, padding_len) + generation_tokens = torch.cat([generation_tokens, padding_tokens], dim=1) + # Cast to LongTensor + indices_tensor = generation_tokens.long() + # First, we reshape the generated tokens into batch x time x height x width + indices_tensor = rearrange( + indices_tensor, + "B (T H W) -> B T H W", + T=latent_shape[0], + H=latent_shape[1], + W=latent_shape[2], + ) + log.debug(f"generated video tokens {len(generation_tokens[0])} -> reshape: {indices_tensor.shape}") + # If logit clipping range is specified, offset the generated indices by the logit_clipping_range[0] + # Video decoder always takes tokens in the range (0, N-1). So, this offset is needed. + if len(logit_clipping_range) > 0: + indices_tensor = indices_tensor - logit_clipping_range[0] + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._load_tokenizer() + + # Now decode the video using tokenizer. + video_decoded = self.model.tokenizer.video_tokenizer.decode(indices_tensor.cuda()) + # Normalize decoded video from [-1, 1] to [0, 1], and clip value + video_decoded = (video_decoded * 0.5 + 0.5).clamp_(0, 1) + return video_decoded, indices_tensor + + +class ARVideo2WorldGenerationPipeline(ARBaseGenerationPipeline): + """Video-to-world generation pipeline with text conditioning capabilities. + + Extends the base autoregressive generation pipeline by adding: + - Text prompt processing and embedding + - Text-conditioned video generation + - Additional safety checks for text input + - Memory management for text encoder model + + Enables generating video continuations that are guided by both + input video frames and text descriptions. 
+ + Additional attributes compared to ARBaseGenerationPipeline: + offload_text_encoder_model (bool): Whether to offload text encoder from GPU after use + """ + + def __init__( + self, + checkpoint_dir: str, + checkpoint_name: str, + inference_type: str = None, + has_text_input: bool = True, + disable_diffusion_decoder: bool = False, + offload_guardrail_models: bool = False, + offload_diffusion_decoder: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + disable_guardrail: bool = False, + parallel_size: int = 1, + ): + """Initialize text-conditioned video generation pipeline. + + Args: + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the checkpoint to load + inference_type: Type of world generation workflow + has_text_input: Whether the pipeline takes text input for world generation + disable_diffusion_decoder: Whether to disable diffusion decoder stage + offload_guardrail_models: Whether to offload content filtering models + offload_diffusion_decoder: Whether to offload diffusion decoder + offload_network: Whether to offload AR model from GPU + offload_tokenizer: Whether to offload tokenizer from GPU + disable_guardrail: Whether to disable guardrail + offload_text_encoder_model: Whether to offload text encoder + parallel_size: Number of GPUs for parallelism + """ + super().__init__( + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + inference_type=inference_type, + has_text_input=has_text_input, + disable_diffusion_decoder=disable_diffusion_decoder, + offload_guardrail_models=offload_guardrail_models, + offload_diffusion_decoder=offload_diffusion_decoder, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + disable_guardrail=disable_guardrail, + parallel_size=parallel_size, + ) + self.offload_text_encoder_model = offload_text_encoder_model + if not self.offload_text_encoder_model: + self._load_text_encoder_model() + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + prompt_mask: torch.Tensor, + inp_vid: torch.Tensor, + num_input_frames: int, + seed: int, + sampling_config: SamplingConfig, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """Run model generation with memory management. + + Executes generation process and handles model offloading to manage GPU memory. + + Args: + prompt_embedding: Text prompt embeddings tensor + prompt_mask: Attention mask for prompt embeddings + inp_vid: Input video tensor + num_input_frames: Number of input frames to use + seed: Random seed for reproducibility + sampling_config: Configuration for sampling parameters + + Returns: + tuple: ( + List of generated video tensors + List of token index tensors + List of prompt embedding tensors + ) + """ + out_videos, indices_tensor, prompt_embedding = self._run_model( + prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config + ) + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + return out_videos, indices_tensor, prompt_embedding + + def _run_model( + self, + prompt_embedding: torch.Tensor, + prompt_mask: torch.Tensor, + inp_vid: torch.Tensor, + num_input_frames: int, + seed: int, + sampling_config: SamplingConfig, + ) -> tuple[List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + """Run core model generation process. + + Handles text-conditioned video generation: + 1. 
Prepares data batch with text embeddings and video + 2. Determines appropriate context length + 3. Generates video tokens with text conditioning + 4. Processes output tensors + + Args: + prompt_embedding: Text prompt embeddings tensor + prompt_mask: Attention mask for prompt embeddings + inp_vid: Input video tensor + num_input_frames: Number of input frames to use + seed: Random seed for reproducibility + sampling_config: Configuration for sampling parameters, + uses default config if None + + Returns: + tuple: ( + List of generated video tensors + List of token index tensors + Text context tensor + ) + """ + if self.parallel_size > 1: + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.parallel_size) + + data_batch = {} + data_batch["context"], data_batch["context_mask"] = prompt_embedding, prompt_mask + T, H, W = self.latent_shape + + if sampling_config is None: + sampling_config = self.sampling_config + if type(inp_vid) is list: + batch_size = len(inp_vid) + elif type(inp_vid) is torch.Tensor: + batch_size = 1 + data_batch["context"] = data_batch["context"].repeat(batch_size, 1, 1) + data_batch["context_mask"] = data_batch["context_mask"].repeat(batch_size, 1) + data_batch["context_mask"] = torch.ones_like(data_batch["context_mask"]).bool() + + latent_context_t_size = 0 + + # Choosing the context length from list of available contexts + context_used = 0 + for _clen in self._supported_context_len: + if num_input_frames >= _clen: + context_used = _clen + latent_context_t_size += 1 + log.info(f"Using context of {context_used} frames") + + num_gen_tokens = int(np.prod([T - latent_context_t_size, H, W])) + + data_batch["video"] = inp_vid + data_batch["video"] = data_batch["video"].repeat(batch_size, 1, 1, 1, 1) + + data_batch = misc.to(data_batch, "cuda") + + log.debug(f" num_tokens_to_generate: {num_gen_tokens}") + log.debug(f" sampling_config: {sampling_config}") + log.debug(f" tokenizer_config: {self.tokenizer_config}") + log.debug(f" latent_shape: {self.latent_shape}") + log.debug(f" latent_context_t_size: {latent_context_t_size}") + log.debug(f" seed: {seed}") + + out_videos_cur_batch, indices_tensor_cur_batch = self.generate_partial_tokens_from_data_batch( + data_batch=data_batch, + num_tokens_to_generate=num_gen_tokens, + sampling_config=sampling_config, + tokenizer_config=self.tokenizer_config, + latent_shape=self.latent_shape, + task_condition="text_and_video", + seed=seed, + ) + + if self.parallel_size > 1: + parallel_state.destroy_model_parallel() + + return out_videos_cur_batch, indices_tensor_cur_batch, data_batch["context"] + + def generate( + self, + inp_prompt: str, + inp_vid: torch.Tensor, + num_input_frames: int = 9, + seed: int = 0, + sampling_config: SamplingConfig = None, + ) -> np.ndarray | None: + """Generate a video guided by text prompt and input frames. + + Pipeline steps: + 1. Validates text prompt safety if enabled + 2. Converts text to embeddings + 3. Generates video with text conditioning + 4. Enhances quality via diffusion decoder + 5. 
Applies video safety checks if enabled + + Args: + inp_prompt: Text prompt to guide the generation + inp_vid: Input video tensor with shape (batch_size, time, channels=3, height, width) + num_input_frames: Number of frames to use as context (default: 9) + seed: Random seed for reproducibility (default: 0) + sampling_config: Configuration for sampling parameters, + uses default config if None + + Returns: + np.ndarray | None: Generated video as numpy array (time, height, width, channels) + if generation successful, None if safety checks fail + """ + if not self.disable_guardrail: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(inp_prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + + log.info("Run text embedding on prompt") + prompt_embeddings, prompt_masks = self._run_text_embedding_on_prompt_with_offload([inp_prompt]) + prompt_embedding = prompt_embeddings[0] + prompt_mask = prompt_masks[0] + log.info("Finish text embedding on prompt") + + log.info("Run generation") + out_videos_cur_batch, indices_tensor_cur_batch, prompt_embedding = self._run_model_with_offload( + prompt_embedding, prompt_mask, inp_vid, num_input_frames, seed, sampling_config + ) + log.info("Finish AR model generation") + + if not self.disable_diffusion_decoder: + log.info("Run diffusion decoder on generated tokens") + out_videos_cur_batch = self._run_diffusion_decoder_with_offload( + out_videos_cur_batch, indices_tensor_cur_batch, [prompt_embedding] + ) + log.info("Finish diffusion decoder on generated tokens") + out_videos_cur_batch = prepare_video_batch_for_saving(out_videos_cur_batch) + output_video = out_videos_cur_batch[0] + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + output_video = self._run_guardrail_on_video_with_offload(output_video) + if output_video is None: + log.critical("Generated video is not safe") + return None + log.info("Finish guardrail on generated video") + + return output_video diff --git a/cosmos_predict1/autoregressive/model.py b/cosmos_predict1/autoregressive/model.py new file mode 100644 index 0000000000000000000000000000000000000000..38179ea5a952600f733cbdc7c7a679ed9a737f8f --- /dev/null +++ b/cosmos_predict1/autoregressive/model.py @@ -0,0 +1,660 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
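Before the model implementation, a condensed sketch of how the ARVideo2WorldGenerationPipeline defined above is typically driven, following the same call pattern as video2world.py. The checkpoint locations, the prompt text, the random input tensor, and the use of SamplingConfig's defaults are illustrative assumptions only; the real demo builds its sampling config through validate_args and loads actual frames with load_vision_input.

import torch

from cosmos_predict1.autoregressive.configs.inference.inference_config import SamplingConfig
from cosmos_predict1.autoregressive.inference.world_generation_pipeline import ARVideo2WorldGenerationPipeline

pipeline = ARVideo2WorldGenerationPipeline(
    inference_type="video2world",
    checkpoint_dir="checkpoints",                      # assumed local checkpoint root
    checkpoint_name="Cosmos-Predict1-5B-Video2World",  # default used by video2world.py
    disable_guardrail=True,                            # keep the sketch light: skip safety models
    disable_diffusion_decoder=True,                    # skip the optional decoder stage
)

# (B, T, 3, H, W) input; the configs above use 640x1024 frames and 9 context frames by default.
inp_vid = torch.rand(1, 9, 3, 640, 1024)  # random stand-in for load_vision_input output

out_vid = pipeline.generate(
    inp_prompt="A robot arm picks up a red cube from a table.",  # illustrative prompt
    inp_vid=inp_vid,
    num_input_frames=9,
    seed=0,
    sampling_config=SamplingConfig(),  # assumes the defaults are usable as-is
)
if out_vid is not None:  # generate() returns None if a guardrail blocks the output
    print(out_vid.shape)  # (T, H, W, C) array, ready for imageio.mimsave(..., fps=25)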
+ +import json +import os +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +import torch +from megatron.core import ModelParallelConfig, parallel_state +from safetensors.torch import load_file +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.autoregressive.configs.base.model import ModelConfig +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.autoregressive.modules.mm_projector import MultimodalProjector +from cosmos_predict1.autoregressive.networks.transformer import Transformer +from cosmos_predict1.autoregressive.networks.vit import VisionTransformer, get_vit_config +from cosmos_predict1.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer, update_vocab_size +from cosmos_predict1.autoregressive.utils.checkpoint import ( + get_partial_state_dict, + obtain_tensor_parallel_state_dict, + process_state_dict, + substrings_to_ignore, +) +from cosmos_predict1.autoregressive.utils.sampling import decode_n_tokens, decode_one_token, prefill +from cosmos_predict1.utils import log, misc + + +def update_model_config(model_config, inference_tensor_parallel_size): + if inference_tensor_parallel_size > 1: + log.warning(f"Setting tensor parallel size to {inference_tensor_parallel_size}") + setattr( + model_config, + "tensor_model_parallel_size", + inference_tensor_parallel_size, + ) + + if "{rank}" in model_config.ckpt_path: + tp_rank = parallel_state.get_tensor_model_parallel_rank() + model_config.ckpt_path = model_config.ckpt_path.format(rank=tp_rank) + return model_config + + +class AutoRegressiveModel(torch.nn.Module): + """ + A class to build and use a AutoRegressiveModel model for text generation. + + Methods: + build: Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. + generate: Generate text sequences based on provided prompts using the language generation model. + """ + + def __init__( + self, + model: Transformer = None, + tokenizer: DiscreteMultimodalTokenizer = None, + config: ModelConfig = None, + model_parallel: ModelParallelConfig = None, + vision_encoder: VisionTransformer = None, + mm_projector: MultimodalProjector = None, + ): + """ + Initialize the AutoRegressiveModel instance with a model and tokenizer. + + Args: + model (Transformer): The Transformer model for text generation. + tokenizer (Tokenizer): The tokenizer for encoding and decoding text. + config (Config): The configuration for the AutoRegressiveModel model. + model_parallel (ModelParallelConfig): The model parallel configuration for the AutoRegressiveModel model. + vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel model. + mm_projector (MultimodalProjector): The multi-modal projector for the AutoRegressiveModel model. + """ + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + + self.vision_encoder = vision_encoder + self.mm_projector = mm_projector + self.model_parallel = model_parallel + + @property + def precision(self): + return self.model.precision + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + def load_ar_model( + self, + shard_checkpoint, + tokenizer_config, + ): + """ + Load the AR model. 
+ """ + model_config = self.config + tensor_parallel_size = 1 if self.model_parallel is None else self.model_parallel.tensor_model_parallel_size + assert tensor_parallel_size == model_config["tensor_model_parallel_size"] + ckpt_path = model_config.ckpt_path + with misc.timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + weights_only=True, + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + orig_precision = torch.get_default_dtype() + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.debug(f"Setting torch default dtype to {precision}") + + model = Transformer( + params=model_config, + model_parallel=self.model_parallel, + tokenizer_config=tokenizer_config, + ) + log.debug( + f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}" + ) + vocab_size = update_vocab_size( + existing_vocab_size=0, + to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size, + training_type=tokenizer_config.training_type, + add_special_tokens=False, + ) + log.debug( + f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}" + ) + # Perform vocab expansion + if vocab_size > model.vocab_size: + log.debug(f"Expanding vocab size to {vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer_config.training_type == "text_to_video") + model.expand_vocab( + vocab_size, + init_method="gaussian", + expand_output_layer=expand_output_layer, + ) + if shard_checkpoint: + # Shard the checkpoint according to tensor parallelism. + with misc.timer("sharding checkpoint according to tensor parallelism"): + if self.model_parallel is not None: + assert self.model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] + llm_checkpoint = obtain_tensor_parallel_state_dict( + llm_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=model_config, + ) + # Remove the "model." prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + self.model = model.to(precision).to("cuda") + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + def load_tokenizer(self, tokenizer_config): + """ + Load the tokenizer. + """ + self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + + @staticmethod + def build( + model_config: ModelConfig = ModelConfig(), + tokenizer_config: TokenizerConfig = None, + model_parallel: ModelParallelConfig = None, + shard_checkpoint: bool = False, + ) -> "AutoRegressiveModel": + """ + Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. 
+ + Args: + model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig(). + tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None. + model_parallel (ModelParallelConfig, optional): The model parallel configuration for the AutoRegressiveModel instance. Defaults to None. + shard_checkpoint (bool, optional): Whether to split the checkpoint by Tensor Parallelism before loading. Defaults to False. + Returns: + AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory. + + Note: + This method sets the device to CUDA and loads the pre-trained model and tokenizer. + """ + tensor_parallel_size = 1 if model_parallel is None else model_parallel.tensor_model_parallel_size + assert tensor_parallel_size == model_config["tensor_model_parallel_size"] + + # Initialize model configuration parameters + config_params = {} + + # Load checkpoint and model parameters + + if model_config.ckpt_path is None: + # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir + ckpt_dir = model_config.ckpt_dir + + # We prioritize safetensors version over the pytorch version, since the former is + # much faster for checkpoint loading. + checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) + if len(checkpoints) == 0: + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert ( + len(checkpoints) == 1 + ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)" + ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case + + if os.path.exists(Path(ckpt_dir) / "config.json"): + with open(Path(ckpt_dir) / "config.json", "r") as f: + config_params = json.loads(f.read()) + else: + log.info( + f"No config.json found in the checkpoint directory ({ckpt_dir}). " f"Using default model config."
+ ) + + else: + # If ckpt_path is provided, we load the model from the specified path, + # and use the default model configuration + ckpt_path = model_config.ckpt_path + + for key, value in config_params.items(): + if hasattr(model_config, key): + # Override the default model configuration with the parameters from the checkpoint + setattr(model_config, key, value) + + with misc.timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + weights_only=True, + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + + if model_config.vision_encoder is not None: + # Take the LLM weights (starting with "model.") from the VLM checkpoint + llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") + if model_config.vision_encoder is not None: + # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']` + # and `checkpoint['mm_projector']` are both for those weights + # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights + if "vision_encoder" in checkpoint: + log.debug("Using pretrained vision_encoder") + vit_checkpoint = checkpoint["vision_encoder"] + else: + log.debug("Using fine-tuned vision_encoder") + vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.") + vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.") + if "mm_projector" in checkpoint: + log.debug("Using pretrained mm_projector") + projector_checkpoint = checkpoint["mm_projector"] + else: + log.debug("Using fine-tuned mm_projector") + projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.") + projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.") + assert ( + len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0 + ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector." + + tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + orig_precision = torch.get_default_dtype() + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.debug(f"Setting torch default dtype to {precision}") + + model = Transformer( + params=model_config, + model_parallel=model_parallel, + tokenizer_config=tokenizer_config, + ) + model_kwargs = {} + + if model_config.vision_encoder is not None: + assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided." 
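+ # Vision tower setup: derive the ViT config for the chosen encoder, reuse the same tensor-parallel size, and build the projector that maps ViT features (vit_config["dim"]) into the LLM embedding width (model_config["dim"]).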
+ vit_config = get_vit_config(model_config.vision_encoder) + vit_config["tensor_model_parallel_size"] = tensor_parallel_size + vision_encoder = VisionTransformer.build( + vit_config, + ) + + mm_projector = MultimodalProjector( + mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"] + ) + model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector}) + + # Perform vocab expansion + if tokenizer.vocab_size > model.vocab_size: + log.debug(f"Expanding vocab size to {tokenizer.vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer.training_type == "text_to_video") + model.expand_vocab( + tokenizer.vocab_size, + init_method="gaussian", + expand_output_layer=expand_output_layer, + ) + + if shard_checkpoint: + # Shard the checkpoint according to tensor parallelism. + with misc.timer("sharding checkpoint according to tensor parallelism"): + if model_parallel is not None: + assert model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] + llm_checkpoint = obtain_tensor_parallel_state_dict( + llm_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=model_config, + ) + if model_config.vision_encoder is not None: + # Shard vision encoder and multimodal projector weights + vit_checkpoint = obtain_tensor_parallel_state_dict( + vit_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=vit_config, + ) + + # Remove the "model." prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + if model_config.vision_encoder is not None: + vision_encoder.load_state_dict(vit_checkpoint) + mm_projector.load_state_dict(projector_checkpoint) + if model_config.vision_encoder_in_channels != 3: + vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels) + + model = model.to(precision) # ensure model parameters are in the correct precision + log.debug(f"Model config: {model_config}") + + model_class = AutoRegressiveModel + + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + return model_class(model, tokenizer, model_config, **model_kwargs) + + @torch.no_grad() + def generate( + self, + prompt_tokens: List[List[int]] | torch.Tensor, + max_gen_len: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + echo: bool = False, + seed: int = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + compile_sampling: bool = True, + compile_prefill: bool = False, + verbose: bool = True, + stop_tokens: Optional[Set[int]] = None, + images: Optional[torch.Tensor] = None, + ): + """ + Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast). 
+ + Args: + prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len). + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0. + top_k (int, optional): Top-k value for top-k sampling. Defaults to None. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None. + num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + seed (int, optional): Random seed for reproducibility. Defaults to None. + compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True. + compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. + verbose (bool, optional): Flag indicating whether to print timing information. Defaults to True. + """ + assert top_k is None or top_p is None, f"Only one of top_k ({top_k}) or top_p ({top_p}) should be specified." + if temperature == 0: + top_p, top_k = None, None + log.debug("Setting top_p and top_k to None because temperature is 0") + if top_p is not None: + log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}") + elif top_k is not None: + log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}") + else: + log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None") + + orig_precision = torch.get_default_dtype() + torch.set_default_dtype(self.precision) + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + # Experimental features to reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + + if seed is not None: + misc.set_random_seed(seed) + + assert not logprobs, "logprobs are not supported for fast_generate yet" + # Check whether the prefill and decode_one_token functions have been compiled yet. If not, compile them based on the flags + if compile_sampling and not getattr(self, "inference_decode_compiled", False): + log.info("Compiling AR sampling function. Note: the first run will be slower due to compilation") + self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + self.inference_decode_compiled = True + log.info("Compiled AR sampling function.") + if compile_prefill and not getattr(self, "inference_prefill_compiled", False): + log.info("Compiling prefill function. Note: the first run will be slower due to compilation") + self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + self.inference_prefill_compiled = True + log.info("Compiled prefill function.") + + if not hasattr(self, "decode_one_token"): + self.decode_one_token = decode_one_token + if not hasattr(self, "prefill"): + self.prefill = prefill + + # Initialization and Assertions + if isinstance(self.model.params, list): + # During training, model.params is a list + log.debug( + f"Found self.model.params is a list; using self.config instead. 
Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" + ) + params = self.config + else: + params = self.model.params + if isinstance(prompt_tokens, list): + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda") + if prompt_tokens.ndim == 1: + prompt_tokens = prompt_tokens.view(1, -1) + else: + assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}" + batch_size, prompt_len = prompt_tokens.shape + total_len = min(params.max_seq_len, max_gen_len + prompt_len) + if max_gen_len + prompt_len > params.max_seq_len: + log.warning( + f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}" + ) + max_gen_len = params.max_seq_len - prompt_len + + if context_mask is not None: + context_mask = context_mask.to(dtype=torch.bool) + if context_mask.ndim == 2: + assert ( + context_mask.shape[0] == batch_size + ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}" + # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len] + context_mask = context_mask.view(batch_size, 1, 1, -1) + + if num_gen_seq > 1: + assert ( + batch_size == 1 + ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts" + log.debug(f"Generating {num_gen_seq} sequences with the same prompt") + assert ( + num_gen_seq <= params.max_batch_size + ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}" + # repeat the prompt tokens for num_gen_seq times + prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1) + assert prompt_tokens.shape == ( + num_gen_seq, + prompt_len, + ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}" + batch_size = len(prompt_tokens) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device) + empty[:, :prompt_len] = prompt_tokens + seq = empty + input_pos = torch.arange(0, prompt_len, device="cuda") + + if verbose: + prefill_start = time.time() + + if images is not None: + images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16) + prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images) + else: + prompt_token_embeddings = None + + if context is not None: + context = context.to(device=prompt_tokens.device, dtype=self.precision) + + # Prefill stage + next_token = self.prefill( + self.model, + input_pos=input_pos, + tokens=prompt_tokens if prompt_token_embeddings is None else None, + token_embeddings=prompt_token_embeddings, + temperature=temperature, + top_k=top_k, + top_p=top_p, + context=context, + context_mask=context_mask, + ) + if verbose: + prefill_time = time.time() - prefill_start + + seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype) + input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda") + stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens + stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda") + + if verbose: + decode_start = time.time() + # Decode stage + generated_tokens = decode_n_tokens( + self.model, + next_token.view(batch_size, -1), + input_pos, + max_gen_len - 1, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=stop_tokens, + decode_one_token_function=self.decode_one_token, + context=context, + context_mask=context_mask, + ) + gen_len = 
len(generated_tokens) + if verbose: + decode_time = time.time() - decode_start + prefill_throughput = prompt_len / prefill_time + decode_throughput = gen_len / decode_time + log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s") + log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s") + + generated_tokens = torch.cat(generated_tokens, dim=1) + + log.debug(f"generated_tokens: {generated_tokens.shape}") + seq = seq[:, : prompt_len + 1 + gen_len] + seq[:, prompt_len + 1 :] = generated_tokens + if not echo: + seq = seq[:, prompt_len:] + + torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value + + return seq, None + + def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.tensor) -> torch.Tensor: + """ + Embed vision and language features into a combined representation. + + Args: + input_ids (torch.Tensor): Input token IDs. + images (torch.tensor): Input images. + + Returns: + torch.Tensor: Combined vision-language features. + + Raises: + AssertionError: If vision encoder or mm projector is not initialized, + or if dimensions mismatch. + """ + # Ensure vision encoder and mm projector are initialized + assert self.vision_encoder is not None + assert self.mm_projector is not None + + # Get image token ID and validate it + image_token_id = self.vision_encoder.image_token_id + assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}" + + # Identify text and image locations in the input + text_locations = input_ids != image_token_id + image_locations = input_ids == image_token_id + + # Process text features + text_features = self.model.tok_embeddings(input_ids[text_locations]) + + # Process image features + images = images.to(device=text_features.device, dtype=text_features.dtype) + vit_outputs = self.vision_encoder(images) + image_features = self.mm_projector(vit_outputs) + + # Get dimensions + B, seq_len = input_ids.shape + N_total = B * seq_len + N_txt, D_txt = text_features.shape + N_img, N_patch, D_img = image_features.shape + + # Reshape image features + image_features = image_features.reshape(N_img * N_patch, D_img) + + # Validate dimensions + assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" + assert ( + N_total == N_txt + N_img * N_patch + ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}" + + # Combine text and image features + combined_features = torch.empty( + (B, seq_len, D_txt), + dtype=text_features.dtype, + device=text_features.device, + ) + combined_features[text_locations, :] = text_features + combined_features[image_locations, :] = image_features + + return combined_features + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). 
+ """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if strict: + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + return _IncompatibleKeys(actual_missing_keys, unexpected_keys) diff --git a/cosmos_predict1/autoregressive/modules/__init__.py b/cosmos_predict1/autoregressive/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/modules/attention.py b/cosmos_predict1/autoregressive/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea6af4c39530aee882a58f5f956d1206a069562 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/attention.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union + +import torch +from megatron.core import parallel_state +from torch import nn +from torch.distributed._functional_collectives import all_reduce + +from cosmos_predict1.autoregressive.modules.embedding import RotaryPositionEmbedding +from cosmos_predict1.autoregressive.modules.normalization import create_norm + + +class Attention(nn.Module): + """ + Attenion layer with KV cache. + """ + + def __init__( + self, + n_heads: int, + n_kv_heads: Union[int, None], + dim: int, + max_batch_size: int, + max_seq_len: int, + context_dim: Optional[int] = None, + use_qk_normalization: bool = False, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + causal_mask: Optional[bool] = True, + head_dim: Optional[int] = None, + fuse_qkv: bool = False, + precision: str = "bfloat16", + tensor_parallel_size: int = 1, + attn_type: str = "self", + ): + """ + Initializes the GQA module. + + Args: + n_heads (int): The number of attention heads. 
+ n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads. + dim (int): The dimensionality of the input and output. + max_batch_size (int): The maximum batch size. + max_seq_len (int): The maximum sequence length. + context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None. + use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False. + norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm". + norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5. + tp_group (int, optional): The tensor parallel group. + causal_mask (bool, optional): Whether to use causal mask. Defaults to True. + head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads. + fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False. + precision (str, optional): The precision of the module. Defaults to "bfloat16". + tensor_parallel_size (int, optional): The tensor parallel size. Defaults to 1. + attn_type (str, optional): The type of attention. Defaults to "self". + """ + super().__init__() + assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}" + self.attn_type = attn_type + self.tp_size = tensor_parallel_size + context_dim = dim if context_dim is None else context_dim + + self.dim = dim + self.context_dim = context_dim + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_kv_heads = self.n_kv_heads // self.tp_size + self.n_local_heads = n_heads // self.tp_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads if head_dim is None else head_dim + self.causal_mask = causal_mask + self.fuse_qkv = fuse_qkv + self.precision = precision + + if fuse_qkv: + assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})" + self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim + self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False) + # Register hook to load fused QKV weights + self._register_load_state_dict_pre_hook(self.load_hook) + else: + self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False) + self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False) + + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + + if self.attn_type == "self": + # Cache for key and value tensors + self.init_kv_cache() + + # QK normalization layers + if use_qk_normalization: + assert n_heads % self.tp_size == 0, "n_heads must be divisible by tensor_model_parallel_size" + assert self.n_kv_heads % self.tp_size == 0, "n_kv_heads must be divisible by tensor_model_parallel_size" + self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + + self.use_qk_normalization = use_qk_normalization + + self.to(dtype=getattr(torch, self.precision)) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def init_kv_cache(self, 
dtype=None): + cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim) + if dtype is None: + dtype = getattr(torch, self.precision) + if self.attn_type == "self": + self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda() + self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda() + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbedding, + input_pos: torch.Tensor, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ): + """ + Forward pass of GQA. + + Args: + x: The input tensor of shape (batch_size, seq_len, dim). + rope: The rotary positional embedding module. + input_pos: The starting position of the current sequence. + mask: The attention mask tensor. + context: The context tensor of shape (batch_size, context_len, dim). + + Returns: + The output tensor after applying GQA. + """ + bsz, seqlen, _ = x.shape + + # Use one single module to handle both self-attn and cross-attn + context = x if context is None else context + context_len = seqlen if context is None else context.shape[1] + + if self.fuse_qkv: + q_size = self.n_local_heads * self.head_dim + kv_size = self.n_local_kv_heads * self.head_dim + xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1) + else: + # Compute query, key, and value projections + xq, xk, xv = self.wq(x), self.wk(context), self.wv(context) + + # Reshape projections + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + + # QK normalization + if self.use_qk_normalization: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + # Apply rotary positional embeddings to queries and keys + # Only apply RoPE to self-attention! + if self.attn_type in ["self", "full"]: + xq, xk = rope(xq, xk, input_pos, seqlen) + + xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) + # xq: (bs, n_local_heads, seqlen, head_dim) + # xk: (bs, n_kv_heads, cache_len + context_len, head_dim) + # xv: (bs, n_kv_heads, cache_len + context_len, head_dim) + if self.attn_type == "self": + # Update cache with current key and value tensors + assert input_pos is not None + self.cache_k[:bsz, :, input_pos] = xk + self.cache_v[:bsz, :, input_pos] = xv + keys, values = ( + self.cache_k[:bsz, :, :], + self.cache_v[:bsz, :, :], + ) + else: + keys, values = xk, xv + + # Repeat keys and values if necessary + keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + + # For self-attention, `is_causal` should be set to False when KV cache is pre-computed and used, + # since the masking is handled outside this attention module. 
+ # For cross-attention, it's always full-attn without causal mask + is_causal = False + output = scaled_dot_product_attention( + xq, + keys, + values, + head_dim=self.head_dim, + mask=mask, + is_causal=is_causal, + dropout_p=0.0, + ) + output = output.view(bsz, seqlen, -1) + output = self.wo(output) + if self.tp_size > 1: + output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + +def scaled_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + head_dim: int, + mask: Optional[torch.Tensor] = None, + is_causal: Optional[bool] = None, + dropout_p: float = 0.0, +) -> torch.Tensor: + """ + PyTorch's native implementation of Flash Attention 2. + + If `is_causal` is given, then the causal attention mask is applied accordingly: + - If `is_causal` is True, the standard upper-left causal attention masking is applied. + - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is + provided (i.e., `mask is not None`). + + If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied + based on the provided mask tensor: + - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True, + leading to the standard upper-left causal attention masking. + - If an attention mask is given (i.e., `mask is not None`), the provided mask is used, + and `is_causal` is set to False. + + Args: + q (torch.Tensor): Query tensor + k (torch.Tensor): Key tensor + v (torch.Tensor): Value tensor + head_dim (int): Dimension of each attention head + mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None. + is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None. + dropout_p (float, optional): Dropout rate. Defaults to 0.0. + + Returns: + torch.Tensor: Output tensor after applying scaled dot-product attention + """ + scale = 1.0 / math.sqrt(head_dim) + if is_causal is None: + is_causal = mask is None + y = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + scale=scale, + is_causal=is_causal, + ) + return y.transpose(1, 2).contiguous() diff --git a/cosmos_predict1/autoregressive/modules/embedding.py b/cosmos_predict1/autoregressive/modules/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b167b331ccf4b79edd1d95fecd20bb161c5115 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/embedding.py @@ -0,0 +1,649 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
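The grouped-query attention in `Attention.forward` above shares each key/value head across `n_rep = n_heads // n_kv_heads` query heads by replicating the KV heads with `repeat_interleave` before calling scaled dot-product attention. A minimal, self-contained sketch of that replication with toy shapes (standalone, not the module's cached inference path):

```python
# GQA head replication: each of the n_kv_heads key/value heads serves
# n_rep = n_heads // n_kv_heads query heads.
import torch

bsz, seqlen, head_dim = 2, 5, 8
n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads  # 4 query heads per KV head

xq = torch.randn(bsz, n_heads, seqlen, head_dim)
xk = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
xv = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

keys = xk.repeat_interleave(n_rep, dim=1)    # (bsz, n_heads, seqlen, head_dim)
values = xv.repeat_interleave(n_rep, dim=1)  # (bsz, n_heads, seqlen, head_dim)

out = torch.nn.functional.scaled_dot_product_attention(xq, keys, values, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 5, 8])
```

This is also why the KV cache above is allocated with `n_local_kv_heads` rather than `n_local_heads`: cache memory scales with the number of KV heads, and the replication happens only at attention time.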
+ +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +from einops import rearrange, repeat +from megatron.core import parallel_state + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def _rotate_half_te(x: torch.Tensor) -> torch.Tensor: + """ + change sign so the last dimension becomes [-odd, +even]. + Adopted from TransformerEngine. + Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py + """ + x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb_te( + t: torch.Tensor, + cos_freqs: torch.Tensor, + sin_freqs: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + Adopted from TransformerEngine. + Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[b, s, h, d]`, on which + rotary positional embedding will be applied. + cos_freqs: torch.Tensor + Cosine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float', + sin_freqs: torch.Tensor + Sine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float', + """ + rot_dim = cos_freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * cos_freqs) + (_rotate_half_te(t) * sin_freqs) + output = torch.cat((t, t_pass), dim=-1) + return output + + +def get_pos_emb_on_this_cp_rank(pos_emb: torch.Tensor, seq_dim: int) -> torch.Tensor: + """ + Get the position embedding for the current context parallel rank. + + Args: + pos_emb (torch.Tensor): The position embedding tensor. + seq_dim (int): The sequence dimension to slice. + + Returns: + torch.Tensor: The position embedding tensor for the current rank. + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(non_blocking=True) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]) + pos_emb = pos_emb.index_select(seq_dim, cp_idx) + pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :]) + return pos_emb + + +def get_pos_emb_on_this_sptp_rank(pos_emb: torch.Tensor, seq_dim: int) -> torch.Tensor: + """ + Get the position embedding for the current tensor parallel rank (only used when sequence parallel is turned on) + + Args: + pos_emb (torch.Tensor): The position embedding tensor. + seq_dim (int): The sequence dimension to slice. + + Returns: + torch.Tensor: The position embedding tensor for the current rank. 
+ """ + tp_size = parallel_state.get_tensor_model_parallel_world_size() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pos_emb_chunks = torch.chunk(pos_emb, tp_size, dim=seq_dim) + pos_emb = pos_emb_chunks[tp_rank] + return pos_emb + + +class RotaryPositionEmbedding(torch.nn.Module): + """ + Rotary Position Embedding module as described in the paper: + https://arxiv.org/abs/2104.09864 + + This module implements rotary positional embeddings, which are used to + enhance the performance of transformer models. + + Args: + dim (int): Dimensionality of the input tensor. + max_position_embeddings (Optional[int]): Maximum position embeddings. + original_max_position_embeddings (Optional[int]): Original maximum position embeddings. + rope_theta (Optional[float]): Base for the frequency calculation. + apply_yarn (Optional[bool]): Whether to apply YaRN (Yet another Rotary). + scale (Optional[int]): Scaling factor for the frequency calculation. + extrapolation_factor (Optional[int]): Extrapolation factor for the frequency extension. + attn_factor (Optional[int]): Attention factor for the frequency calculation. + beta_fast (Optional[int]): Fast beta value for the YaRN frequency calculation. + beta_slow (Optional[int]): Slow beta value for the YaRN frequency calculation. + rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D". + latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs. + original_latent_shape (Optional[List[int]]): Original shape of the latent tensor for video or image inputs. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + """ + + def __init__( + self, + dim: int, + max_position_embeddings: Optional[int] = None, + original_max_position_embeddings: Optional[int] = None, + rope_theta: Optional[float] = 10000.0, + apply_yarn: Optional[bool] = False, + scale: Optional[int] = None, + extrapolation_factor: Optional[int] = 1, + attn_factor: Optional[int] = 1, + beta_fast: Optional[int] = 32, + beta_slow: Optional[int] = 1, + rope_dim: Optional[str] = "1D", + latent_shape: Optional[List[int]] = None, + original_latent_shape: Optional[List[int]] = None, + pad_to_multiple_of: Optional[int] = None, + ): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.rope_theta = rope_theta + self.apply_yarn = apply_yarn + self.scale = scale + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = 1.0 + self.rope_dim = rope_dim + self.latent_shape = latent_shape + self.original_latent_shape = original_latent_shape + self.pad_to_multiple_of = pad_to_multiple_of + self.get_inv_freq(torch.cuda.current_device()) + + def get_mscale(self, scale: float = 1.0) -> float: + """Get the magnitude scaling factor for YaRN.""" + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + def forward(self, seq_len: Optional[int] = None) -> torch.Tensor: + """ + Forward pass for the rotary position embedding. + + Args: + seq_len (Optional[int]): Length of the sequence. + + Returns: + torch.Tensor: The computed frequencies for positional embedding. 
+ """ + + if self.apply_yarn and seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + self.freqs = self.compute_freqs() + + return self.freqs + + def compute_freqs( + self, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute the spatial frequencies for the latent tensor.""" + self.seq = torch.arange(self.max_seq_len_cached, dtype=torch.float).cuda() + if self.rope_dim == "1D": + emb = torch.einsum("i,j->ij", self.seq, self.inv_freq) + + elif self.rope_dim == "2D": + H, W = self.latent_shape + half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq) + half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq) + emb = torch.cat( + [ + repeat(half_emb_h, "h d -> h w d", w=W), + repeat(half_emb_w, "w d -> h w d", h=H), + ] + * 2, + dim=-1, + ) + emb = rearrange(emb, "h w d -> (h w) 1 1 d").float() + + elif self.rope_dim == "3D": + T, H, W = self.latent_shape + half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq) + half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq) + half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq) + emb = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + return emb + + def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor: + """Get the scale factors for YaRN.""" + # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called + # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code. + high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len + low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len + # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear + # interpolation in between. + smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1) + # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency. + scale_factors = (1 - smooth_mask) / self.scale + smooth_mask + return scale_factors + + def get_inv_freq(self, device: torch.device) -> None: + """Get the inverse frequency.""" + if self.rope_dim == "1D": + assert self.max_position_embeddings is not None, "Max position embeddings required." + inv_freq = 1.0 / ( + self.rope_theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim) + ) + if self.apply_yarn: + assert self.original_max_position_embeddings is not None, "Original max position embeddings required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." + + scale_factors = self.get_scale_factors(inv_freq, self.original_max_position_embeddings) + # Apply the scaling factors to inv_freq. + inv_freq = inv_freq * scale_factors + # Set the magnitude scaling factor. + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.max_seq_len_cached = self.max_position_embeddings + self.inv_freq = inv_freq + + elif self.rope_dim == "2D": + assert self.latent_shape is not None, "Latent shape required." 
+ dim_h = self.dim // 2 + spatial_inv_freq = 1.0 / ( + self.rope_theta ** torch.arange(0, dim_h, 2, dtype=torch.float32, device=device) / dim_h + ) + if self.apply_yarn: + assert self.original_latent_shape is not None, "Original latent shape required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." + + scale_factors = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[0]) + spatial_inv_freq = spatial_inv_freq * scale_factors + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.spatial_inv_freq = spatial_inv_freq + self.max_seq_len_cached = max(self.latent_shape) + + elif self.rope_dim == "3D": + assert self.latent_shape is not None, "Latent shape required." + dim_h = self.dim // 6 * 2 + dim_t = self.dim - 2 * dim_h + self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(device) / dim_h + spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range) + self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(device) / dim_t + temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range) + if self.apply_yarn: + assert self.original_latent_shape is not None, "Original latent shape required." + assert self.beta_slow is not None, "Beta slow value required." + assert self.beta_fast is not None, "Beta fast value required." + scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1]) + spatial_inv_freq = spatial_inv_freq * scale_factors_spatial + scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0]) + temporal_inv_freq = temporal_inv_freq * scale_factors_temporal + self.mscale = float(self.get_mscale(self.scale) * self.attn_factor) + self.spatial_inv_freq = spatial_inv_freq + self.temporal_inv_freq = temporal_inv_freq + self.max_seq_len_cached = max(self.latent_shape) + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + + self.freqs = self.compute_freqs() + + +class RotaryPositionEmbeddingTE(RotaryPositionEmbedding): + """ + Rotary Position Embedding with context parallelism support. + + """ + + def __init__( + self, + **kwargs, + ): + super().__init__( + **kwargs, + ) + + def forward(self, seq_len: int, training_type: str = None) -> torch.Tensor: + """ + Create rotary position embedding frequencies. + + Args: + seq_len (int): Sequence length of a sample. + + Returns: + torch.Tensor: The computed positional embeddings. 
+ """ + if self.rope_dim == "1D": + freqs = super().forward(seq_len=seq_len) + emb = torch.cat((freqs, freqs), dim=-1) + emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) + + elif self.rope_dim in ["2D", "3D"]: + emb = super().forward(seq_len=seq_len) + if training_type == "text_to_video": + # since we added token at the beginning of the video for text2video, we also extend the position embedding by one token in the beginning + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device) + emb = torch.cat((bov_pe, emb), dim=0) + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0) + + return emb + + +class RotaryPositionEmbeddingPytorch(RotaryPositionEmbedding): + """ + Rotary Position Embedding with PyTorch specific adjustments. + + """ + + def __init__( + self, + **kwargs, + ): + super().__init__( + **kwargs, + ) + if self.rope_dim == "1D": + emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(self.freqs, "s 1 1 d -> s d").float() + self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False) + self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False) + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dimensions of the input tensor.""" + x_reshaped = x.reshape(*x.shape[:-1], -1, 2) + x1 = x_reshaped[..., 0] + x2 = x_reshaped[..., 1] + output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape) + return output + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the rotary position embedding. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + input_pos (Optional[torch.Tensor]): Starting position for the sequence. + seq_len (Optional[int]): Length of the sequence. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + if self.apply_yarn and seq_len > self.max_seq_len_cached: + freqs = super().forward(seq_len) + if self.rope_dim == "1D": + emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(freqs, "s 1 1 d -> s d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + self.register_buffer( + "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + + if input_pos is not None: + cos_cached = self.cos_cached[:, input_pos] + sin_cached = self.sin_cached[:, input_pos] + else: + assert ( + self.cos_cached.shape[1] >= seq_len + ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}." + cos_cached = self.cos_cached[:, :seq_len, ...] + sin_cached = self.sin_cached[:, :seq_len, ...] 
+ xq = q * cos_cached + self.rotate_half(q) * sin_cached + xk = k * cos_cached + self.rotate_half(k) * sin_cached + + return xq.type_as(q), xk.type_as(k) + + +class RotaryPositionEmbeddingPytorchV2(RotaryPositionEmbedding): + """ + Rotary Position Embedding that works in the same way as the TransformerEngine RoPE + (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) + + """ + + def __init__( + self, + seq_len: int, + training_type: str = None, + **kwargs, + ): + super().__init__( + **kwargs, + ) + emb = self.create_rope_freqs(seq_len=seq_len, training_type=training_type) + emb = emb.transpose(0, 1).contiguous() # [seq, 1, 1, dim] -> [1, seq, 1, dim] + assert emb.shape[0] == 1 and emb.shape[2] == 1, f"emb shape: {emb.shape}" + # cos/sin first then dtype conversion for better precision + self.register_buffer("cos_cached", torch.cos(emb), persistent=False) + self.register_buffer("sin_cached", torch.sin(emb), persistent=False) + + def create_rope_freqs(self, seq_len: int, training_type: str = None) -> torch.Tensor: + """ + Create rotary position embedding frequencies. + + Args: + seq_len (int): Sequence length of a sample. + + Returns: + torch.Tensor: The computed positional embeddings. + """ + if self.rope_dim == "1D": + freqs = super().forward(seq_len=seq_len) + emb = torch.cat((freqs, freqs), dim=-1) + emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) + + elif self.rope_dim in ["2D", "3D"]: + emb = super().forward(seq_len=seq_len) + if training_type == "text_to_video": + # since we added token at the beginning of the video for text2world, we also extend the position embedding by one token in the beginning + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device) + emb = torch.cat((bov_pe, emb), dim=0) + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + # Round up to the nearest multiple of pad_to_multiple_of + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0) + + return emb + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if q.dtype != self.cos_cached.dtype: + self.cos_cached = self.cos_cached.to(q.dtype) + self.sin_cached = self.sin_cached.to(q.dtype) + + cos_emb = self.cos_cached + sin_emb = self.sin_cached + if input_pos is not None: + cos_emb = cos_emb[:, input_pos, :, :] + sin_emb = sin_emb[:, input_pos, :, :] + elif seq_len is not None: + cos_emb = cos_emb[:, :seq_len, :, :] + sin_emb = sin_emb[:, :seq_len, :, :] + q = _apply_rotary_pos_emb_te(q, cos_emb, sin_emb) + k = _apply_rotary_pos_emb_te(k, cos_emb, sin_emb) + return q, k + + +class RotaryPositionEmbeddingPytorchV1(RotaryPositionEmbedding): + """ + Rotary Position Embedding that works in the same way as + mistral_inference (https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/rope.py) + or llama3 (https://github.com/meta-llama/llama3/blob/main/llama/model.py) + + """ + + def __init__( + self, + **kwargs, + ): + super().__init__( + **kwargs, + ) + if self.rope_dim == "1D": + emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(self.freqs, "s 1 1 d -> s d").float() + self.register_buffer("cos_cached", (emb.cos() * 
self.mscale)[None, :, None, :], persistent=False) + self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False) + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dimensions of the input tensor.""" + x_reshaped = x.reshape(*x.shape[:-1], -1, 2) + x1 = x_reshaped[..., 0] + x2 = x_reshaped[..., 1] + output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape) + return output + + def forward( + self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the rotary position embedding. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + input_pos (Optional[torch.Tensor]): Starting position for the sequence. + seq_len (Optional[int]): Length of the sequence. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + if self.apply_yarn and seq_len > self.max_seq_len_cached: + freqs = super().forward(seq_len) + if self.rope_dim == "1D": + emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1) + elif self.rope_dim in ["2D", "3D"]: + emb = rearrange(freqs, "s 1 1 d -> s d").float() + else: + raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}") + self.register_buffer( + "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False + ) + + if input_pos is not None: + cos_cached = self.cos_cached[:, input_pos] + sin_cached = self.sin_cached[:, input_pos] + else: + assert ( + self.cos_cached.shape[1] >= seq_len + ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}." + cos_cached = self.cos_cached[:, :seq_len, ...] + sin_cached = self.sin_cached[:, :seq_len, ...] + xq = q * cos_cached + self.rotate_half(q) * sin_cached + xk = k * cos_cached + self.rotate_half(k) * sin_cached + + return xq.type_as(q), xk.type_as(k) + + +class SinCosPosEmbAxisTE(torch.nn.Module): + def __init__( + self, + dim: int, + latent_shape: Optional[List[int]] = None, + pad_to_multiple_of: Optional[int] = None, + dtype: torch.dtype = torch.bfloat16, + device="cuda", + **kwargs, + ): + """ + Args: + dim (int): Dimensionality of the input tensor. + latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs. + pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value. + dtype (torch.dtype): Data type of the position embedding tensor. 
+ """ + super().__init__() + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.latent_shape = latent_shape + T, H, W = latent_shape + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(H)) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(W)) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(T)) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).to(dtype=dtype, device=device), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).to(dtype=dtype, device=device), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).to(dtype=dtype, device=device), persistent=False) + self.pad_to_multiple_of = pad_to_multiple_of + + def forward( + self, + training_type: str | None = None, + ) -> torch.Tensor: + T, H, W = self.latent_shape + emb = torch.cat( + [ + repeat(self.pos_emb_t, "t d-> t h w d", h=H, w=W), + repeat(self.pos_emb_h, "h d-> t h w d", t=T, w=W), + repeat(self.pos_emb_w, "w d-> t h w d", t=T, h=H), + ], + dim=-1, + ) + # Flatten the T,H,W dimensions + emb = rearrange(emb, "t h w d -> (t h w) d") + + if training_type == "text_to_video": + bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device, dtype=emb.dtype) + emb = torch.cat((bov_pe, emb), dim=0) + if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0: + pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of + emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)), dim=0) + seq_len, dim = emb.shape + emb = emb.reshape(1, seq_len, dim) + return emb diff --git a/cosmos_predict1/autoregressive/modules/linear.py b/cosmos_predict1/autoregressive/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..cce025a8a3c67037865791202f4ad05ec16673c5 --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/linear.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable + +import torch +from megatron.core import ModelParallelConfig, parallel_state +from megatron.core.tensor_parallel import ColumnParallelLinear as McoreColumnParallelLinear +from megatron.core.tensor_parallel import RowParallelLinear as McoreRowParallelLinear +from megatron.core.tensor_parallel import VocabParallelEmbedding as McoreVocabParallelEmbedding +from megatron.core.tensor_parallel.mappings import ( + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from megatron.core.tensor_parallel.utils import VocabUtility +from torch.distributed import _functional_collectives as funcol +from torch.distributed._functional_collectives import all_reduce + + +class VocabParallelEmbedding(torch.nn.Module): + """ + Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + + Args: + num_embeddings (int): vocabulary size. + embedding_dim (int): size of hidden state. + precision (str): precision of the embedding. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + precision: str = "bfloat16", + ): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. + (self.vocab_start_index, self.vocab_end_index) = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, + parallel_state.get_tensor_model_parallel_rank(), + self.tensor_model_parallel_size, + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index + + self.weight = torch.nn.Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=getattr(torch, precision), + ) + ) + + def forward(self, input_): + """Forward. + + Args: + input_ (torch.Tensor): Input tensor. + """ + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output = self.weight[masked_input] + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output[input_mask, :] = 0.0 + + output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + +class ColumnParallelLinear(McoreColumnParallelLinear): + """ + A modified version of Mcore's ColumnParallelLinear that only returns the output tensor. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input_: torch.Tensor): + """ + Performs the forward pass of the column parallel linear layer. + + Args: + input_ (torch.Tensor): The input tensor. + weight (Optional[torch.Tensor], optional): The weight tensor. If None, uses the layer's own weight. + + Returns: + torch.Tensor: The output tensor after the linear transformation. + """ + output, _ = super().forward(input_) + return output + + +class RowParallelLinear(McoreRowParallelLinear): + """ + A modified version of Mcore's RowParallelLinear that only returns the output tensor. 
+ + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input_: torch.Tensor): + """ + Performs the forward pass of the Row Parallel linear layer. + + Args: + input_ (torch.Tensor): The input tensor. + weight (Optional[torch.Tensor], optional): The weight tensor. If None, uses the layer's own weight. + + Returns: + torch.Tensor: The output tensor after the linear transformation. + """ + output, _ = super().forward(input_) + return output + + +class TrainingVocabParallelEmbedding(McoreVocabParallelEmbedding): + """ + Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + + Args: + num_embeddings (int): vocabulary size. + embedding_dim (int): size of hidden state. + + Keyword Args: + sequence_parallel (bool): Decides whether to perform ReduceScatter after embedding lookup + batch_first (bool): If True, then output tensor shape is [batch, seq, feature]. If False, then shape becomes + [seq, batch, feature]. Note: We assume the input tensor is always in the shape of [seq, batch]. + config: A megatron.core.ModelParallelConfig object + use_inference_allreduce (bool): If True, then Megatron's allreduce in the forward pass is disabled, and the pytorch's + allreduce is used instead (inference mode only). + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + *, + init_method: Callable, + sequence_parallel: bool = False, + batch_first: bool = False, + config: ModelParallelConfig, + use_inference_allreduce: bool = False, + ): + super(TrainingVocabParallelEmbedding, self).__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + init_method=init_method, + config=config, + ) + self.sequence_parallel = sequence_parallel + if sequence_parallel: + # If sequence parallel, then the output tensor should be in the shape of [seq, batch, feature] + batch_first = False + self.batch_first = batch_first + self.use_inference_allreduce = use_inference_allreduce + + def forward(self, input_): + """Forward. + + Args: + input_ (torch.Tensor): Input tensor. + """ + if self.tensor_model_parallel_size > 1: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output = self.weight[masked_input] + # Mask the output embedding. + if self.tensor_model_parallel_size > 1: + output[input_mask, :] = 0.0 + + if self.sequence_parallel: + assert not self.batch_first + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + output = output.transpose(0, 1).contiguous() + if not self.use_inference_allreduce: + output = reduce_scatter_to_sequence_parallel_region(output) + else: + # Reduce across all the model parallel GPUs. 
+ if not self.use_inference_allreduce: + output = reduce_from_tensor_model_parallel_region(output) + if not self.batch_first: + # Shape: [b, s, h] --> [s, b, h] + output = output.transpose(0, 1).contiguous() + + if self.use_inference_allreduce: + output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output diff --git a/cosmos_predict1/autoregressive/modules/mlp.py b/cosmos_predict1/autoregressive/modules/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..61ef18a8049d65760f3c2fdaaa8a21845d706cfa --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/mlp.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core import ModelParallelConfig, parallel_state +from torch.distributed import _functional_collectives as funcol +from torch.distributed._functional_collectives import all_reduce + +from cosmos_predict1.autoregressive.modules.linear import ColumnParallelLinear, RowParallelLinear + + +def compute_llama3_ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: float) -> int: + """ + Computes the feedforward network dimensionality. + + Args: + dim (int): The embedding dimensionality. + multiple_of (int): The multiple to round up the hidden dimensionality. + ffn_dim_multiplier (float): The multiplier for the hidden dimensionality. + + Returns: + The feedforward network dimensionality. + """ + hidden_dim = 4 * dim + hidden_dim = int(2 * hidden_dim / 3) # custom dim factor + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + # Round up hidden dimensionality to the nearest multiple + return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + tensor_parallel_size: int = 1, + ): + """ + Initializes the multilayer perceptron (MLP) module. + + Args: + dim: The input and output dimensionality. + hidden_dim: The dimensionality of the hidden layer. + """ + super().__init__() + self.tp_size = tensor_parallel_size + self.w1 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + self.w2 = nn.Linear(hidden_dim // self.tp_size, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs the forward pass of the MLP module. + + Args: + x: The input tensor of shape (batch_size, dim). + + Returns: + The output tensor of shape (batch_size, dim). 
+ """ + output = self.w2(F.silu(self.w1(x)) * self.w3(x)) + if self.tp_size > 1: + output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + +class TrainingMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + hidden_dropout: float = 0.0, + set_parallel_mode: bool = False, + model_parallel: Optional[ModelParallelConfig] = None, + inference: bool = False, + ): + """ + Initializes the multilayer perceptron (MLP) module. + + Args: + dim: The input and output dimensionality. + hidden_dim: The dimensionality of the hidden layer. + hidden_dropout: Dropout after the attention and feed-forward layers (following TransformerEngine's + implementation in its TransformerLayer class). + set_parallel_mode: Whether to use column and row parallel linear layers. + model_parallel: The model parallel configuration. + inference: Whether the model is used for inference. + """ + super().__init__() + self.hidden_dropout = hidden_dropout + if model_parallel and model_parallel.tensor_model_parallel_size > 1: + self.tp_size = model_parallel.tensor_model_parallel_size + else: + self.tp_size = 1 + if set_parallel_mode and not inference: + kwargs = {"bias": False, "init_method": lambda x: x, "config": model_parallel} + # Using column and row parallel linear layers + self.w1 = ColumnParallelLinear(dim, hidden_dim, gather_output=False, **kwargs) + self.w2 = RowParallelLinear(hidden_dim, dim, input_is_parallel=True, skip_bias_add=True, **kwargs) + self.w3 = ColumnParallelLinear(dim, hidden_dim, gather_output=False, **kwargs) + else: + self.w1 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + self.w2 = nn.Linear(hidden_dim // self.tp_size, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim // self.tp_size, bias=False) + + self.inference = inference + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs the forward pass of the MLP module. + + Args: + x: The input tensor of shape (batch_size, dim). + + Returns: + The output tensor of shape (batch_size, dim). + """ + x = F.dropout(x, p=self.hidden_dropout, training=self.training) + output = self.w2(F.silu(self.w1(x)) * self.w3(x)) + output = F.dropout(output, p=self.hidden_dropout, training=self.training) + + if self.inference and self.tp_size > 1: + output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + def init_weights(self, init_std: float): + """ + Initializes the weights of the MLP module. + """ + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) diff --git a/cosmos_predict1/autoregressive/modules/mm_projector.py b/cosmos_predict1/autoregressive/modules/mm_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..ee54c961498ff108a92fe621e9322649f7ad891b --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/mm_projector.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
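In the tensor-parallel branches above, w1 and w3 are split along the hidden (output) dimension, so each rank keeps hidden_dim // tp_size units, while w2 is split along its input dimension; each rank therefore computes a partial SwiGLU output and a single sum all-reduce restores the full result. A single-process sketch of that equivalence with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden_dim, tp_size = 8, 16, 2
x = torch.randn(2, dim)

# Full (unsharded) SwiGLU MLP weights; nn.Linear stores weights as [out, in].
w1 = torch.randn(hidden_dim, dim)
w2 = torch.randn(dim, hidden_dim)
w3 = torch.randn(hidden_dim, dim)
full = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

# Sharded computation: each "rank" holds hidden_dim // tp_size rows of w1/w3
# and the matching columns of w2; the partial outputs sum to the full output.
shard = hidden_dim // tp_size
partials = []
for rank in range(tp_size):
    sl = slice(rank * shard, (rank + 1) * shard)
    h = F.silu(F.linear(x, w1[sl])) * F.linear(x, w3[sl])
    partials.append(F.linear(h, w2[:, sl]))
tp_out = sum(partials)                           # stands in for the all-reduce in forward()

assert torch.allclose(full, tp_out, atol=1e-5)
```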
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multimodal projector to connect vision encoder / tokenizer with the LLM.""" + +from typing import Any, Optional + +import torch +import torch.nn as nn + + +class DownSampleBlock(nn.Module): + """Downsample block.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + """ + Performs the forward pass of the downsample block. + + Args: + x (torch.Tensor): The input tensor from ViT's output of a sequence of embeddings. + Shape: (b, seq_len, c). + + Returns: + torch.Tensor: The output tensor. Shape: (b, seq_len/4, c*4). + """ + vit_embeds = x + # Get h and w as the sqrt of seq length. This assumes that the input is square-shaped. + h = w = int(vit_embeds.shape[1] ** 0.5) + b = vit_embeds.shape[0] + vit_embeds = vit_embeds.reshape(b, h, w, -1) + vit_embeds = self.flat_square(vit_embeds) + vit_embeds = vit_embeds.reshape(b, -1, vit_embeds.shape[-1]) + return vit_embeds + + def flat_square(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs spatial downsampling while increasing the number of channels. + + Args: + x (torch.Tensor): The input tensor reshaped to a 2D grid. + Shape: (b, h, w, c) + + Returns: + torch.Tensor: The output tensor after the spatial downsampling. + Shape: (b, h/2, w/2, c*4) + """ + b, h, w, c = x.size() + # If w or h is odd, pad a column or a row of zeros. + if h % 2 == 1: + x = torch.concat([x, torch.zeros((b, 1, w, c), dtype=x.dtype).to(x.device)], dim=1).contiguous() + b, h, w, c = x.size() + if w % 2 == 1: + x = torch.concat([x, torch.zeros((b, h, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous() + b, h, w, c = x.size() + # 2x spatial downsampling, 4x channel increasing. 
+ x = x.view(b, h, int(w / 2), int(c * 2)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(b, int(h / 2), int(w / 2), int(c * 4)) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + +class MultimodalProjector(nn.Module): + """Multimodal projector.""" + + def __init__( + self, + mm_projector_type: str, + in_dim: int, + out_dim: Optional[int] = None, + **kwargs: Any, + ): + super().__init__() + if out_dim is None: + out_dim = in_dim + if mm_projector_type == "identity": + self.projector = nn.Identity() + elif mm_projector_type == "linear": + self.projector = nn.Linear(in_dim, out_dim) + elif mm_projector_type == "mlp": + self.projector = nn.Sequential(nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)) + elif mm_projector_type == "mlp_downsample": + self.projector = nn.Sequential( + DownSampleBlock(), + nn.LayerNorm(in_dim * 4), + nn.Linear(in_dim * 4, out_dim), + nn.GELU(), + nn.Linear(out_dim, out_dim), + ) + else: + raise ValueError(f"Unknown projector type: {mm_projector_type}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.projector(x) diff --git a/cosmos_predict1/autoregressive/modules/normalization.py b/cosmos_predict1/autoregressive/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..37af6f2f63ae7aa1bbc37a5c815226ceebf4ccbb --- /dev/null +++ b/cosmos_predict1/autoregressive/modules/normalization.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +def create_norm(norm_type: str, dim: int, eps: float = 1e-6): + """ + Creates the specified normalization layer based on the norm_type. + Adopted from TorchTriton: https://github.com/pytorch/torchtitan/blob/main/torchtitan/cosmos_predict1/norms.py + + Args: + norm_type (str): The type of normalization layer to create. + Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm + dim (int): The dimension of the normalization layer. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + + Returns: + The created normalization layer. + + Raises: + NotImplementedError: If an unknown norm_type is provided. + """ + norm_type = norm_type.lower() # Normalize to lowercase + + if norm_type == "layernorm": + return nn.LayerNorm(dim, eps=eps, bias=False) + elif norm_type == "np_layernorm": + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "rmsnorm": + return RMSNorm(dim, eps=eps, compile=False) + elif norm_type == "compiled_rmsnorm": + return RMSNorm(dim, eps=eps, compile=True) + elif norm_type == "fused_rmsnorm": + raise NotImplementedError("Fused RMSNorm is not supported yet.") + else: + raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. 
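For the "mlp_downsample" projector, DownSampleBlock turns a square grid of ViT tokens into a sequence that is 4x shorter with 4x wider channels before the MLP maps it to the LLM width. A small shape check with illustrative sizes (32x32 = 1024 ViT tokens of width 1024 projected to 4096), assuming the MultimodalProjector class above is in scope:

```python
import torch

proj = MultimodalProjector(mm_projector_type="mlp_downsample", in_dim=1024, out_dim=4096)
vit_tokens = torch.randn(2, 1024, 1024)          # (batch, seq_len, channels)

out = proj(vit_tokens)
print(out.shape)                                 # torch.Size([2, 256, 4096])
```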
+ Reference implementation: https://github.com/pytorch/torchtitan/blob/main/torchtitan/cosmos_predict1/norms.py + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + compile (bool, optional): Whether to compile the forward function. Default is False. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm + + @staticmethod + def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float): + def _norm(x, eps): + # Computes the root-mean-square norm of the input tensor. + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + output = _norm(x.float(), eps).type_as(x) + return output * weight + + def forward(self, x: torch.Tensor): + return self.rmsnorm_fn(x, self.weight, self.eps) + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) diff --git a/cosmos_predict1/autoregressive/networks/transformer.py b/cosmos_predict1/autoregressive/networks/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0f573636e2ce27f4f337fd02c0c60de208808664 --- /dev/null +++ b/cosmos_predict1/autoregressive/networks/transformer.py @@ -0,0 +1,519 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from megatron.core import parallel_state +from torch.distributed import broadcast, get_process_group_ranks +from torch.distributed._functional_collectives import all_gather_tensor +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.autoregressive.modules.attention import Attention +from cosmos_predict1.autoregressive.modules.embedding import ( + RotaryPositionEmbeddingPytorchV1, + RotaryPositionEmbeddingPytorchV2, + SinCosPosEmbAxisTE, +) +from cosmos_predict1.autoregressive.modules.linear import VocabParallelEmbedding +from cosmos_predict1.autoregressive.modules.mlp import MLP +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.utils.checkpoint import process_state_dict, substrings_to_ignore +from cosmos_predict1.autoregressive.utils.misc import maybe_convert_to_namespace +from cosmos_predict1.utils import log + + +class TransformerBlock(nn.Module): + """ + A single transformer block consisting of an attention layer and a feed-forward layer. + """ + + def __init__(self, layer_id: int, args=None): + """ + Initializes the TransformerBlock module. 
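The RMSNorm above rescales each feature vector by the reciprocal of its root-mean-square, x * rsqrt(mean(x^2) + eps), and then applies a learned per-channel gain; unlike LayerNorm there is no mean subtraction and no bias. A quick numerical check against a manual computation, assuming the class above is in scope:

```python
import torch

norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(4, 8)

manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), manual, atol=1e-6)
```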
+ + Args: + layer_id: The ID of the transformer block. + args: The model arguments containing hyperparameters. + """ + super().__init__() + args = maybe_convert_to_namespace(args) + attention_args = { + "n_heads": args["n_heads"], + "n_kv_heads": args["n_kv_heads"], + "dim": args["dim"], + "context_dim": None, + "max_batch_size": args["max_batch_size"], + "max_seq_len": args["max_seq_len"], + "use_qk_normalization": args["use_qk_normalization"], + "causal_mask": args["causal_mask"], + "head_dim": args["head_dim"], + "fuse_qkv": getattr(args, "fuse_qkv", False), + "precision": getattr(args, "precision", "bfloat16"), + "tensor_parallel_size": args["tensor_model_parallel_size"], + "attn_type": getattr(args, "attn_type", "self"), + } + self.attention = Attention(**attention_args) + + self.has_cross_attention = False + self.cross_attention, self.cross_attention_norm = None, None + + if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0: + self.has_cross_attention = True + cross_attention_args = attention_args.copy() + cross_attention_args.update({"context_dim": args["context_dim"], "fuse_qkv": False, "attn_type": "cross"}) + self.cross_attention = Attention(**cross_attention_args) + self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + self.feed_forward = MLP( + dim=args["dim"], + hidden_dim=args["ffn_hidden_size"], + tensor_parallel_size=args["tensor_model_parallel_size"], + ) + self.layer_id = layer_id + self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbeddingPytorchV2, + input_pos: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the TransformerBlock module. + + Args: + x: The input tensor. + input_pos: The position of the current sequence. Used in inference (with KV cache) only. + freqs_cis: The precomputed frequency values for rotary position embeddings. + mask: The attention mask tensor. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + The output tensor after applying the transformer block. + """ + # Apply attention and residual connection + h = x + self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask) + + # If insert cross-attention, apply CA and residual connection + if self.has_cross_attention: + h = h + self.cross_attention( + self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context + ) + + # Apply feed-forward network and residual connection + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + """ + Initializes the weights of the transformer block. 
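The forward pass above follows the standard pre-norm residual layout: normalize, apply attention, add the residual, then repeat for the feed-forward path (with an optional cross-attention step in between on layers where it is inserted). A minimal sketch of that wiring with stand-in modules (nn.Identity in place of the real Attention and MLP, purely to show the data flow):

```python
import torch
import torch.nn as nn

dim = 16
attn, ffn = nn.Identity(), nn.Identity()              # stand-ins for Attention / MLP
attn_norm, ffn_norm = nn.LayerNorm(dim), nn.LayerNorm(dim)

x = torch.randn(2, 5, dim)
h = x + attn(attn_norm(x))                            # pre-norm self-attention + residual
out = h + ffn(ffn_norm(h))                            # pre-norm feed-forward + residual
print(out.shape)                                      # torch.Size([2, 5, 16])
```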
+ """ + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + if self.has_cross_attention: + self.cross_attention_norm.reset_parameters() + self.cross_attention.init_weights(self.weight_init_std) + # zero-init the final output layer of cross-attention + # nn.init.zeros_(self.cross_attention.wo.weight) + + +class Transformer(nn.Module): + """ + The Transformer network consisting of transformer blocks. + """ + + def __init__(self, params, model_parallel=None, tokenizer_config=None, init_weights: bool = True): + """ + Initializes the Transformer module. + + Args: + params: The model parameters containing hyperparameters. + model_parallel: The model parallel configuration. + tokenizer_config: The model tokenizer configuration. + init_weights (bool): Whether to initialize the weights of the transformer following + TorchTitan's Llama3 initialization scheme. + """ + super().__init__() + # Check if self.params is an OmegaConf DictConfig instance + self.params = maybe_convert_to_namespace(params) + self.vocab_size = params["vocab_size"] + self.n_layers = params["n_layers"] + self.precision = getattr(torch, params["precision"]) + self.tokenizer_config = tokenizer_config + self.model_parallel = model_parallel + self.num_video_frames = params["num_video_frames"] + tp_group = self._get_tp_group() + + # Token embeddings + self.tok_embeddings = self._create_token_embeddings(self.model_parallel) + self.rope_config = self._create_rope_config() + + # Transformer layers + self.layers = nn.ModuleList( + [TransformerBlock(layer_id, self.params).to(self.precision) for layer_id in range(self.n_layers)] + ) + + # Final layer normalization + self.norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"]).to( + self.precision + ) + if self.params["pytorch_rope_version"] == "v1": + self.rope = RotaryPositionEmbeddingPytorchV1(**self.rope_config) + elif self.params["pytorch_rope_version"] == "v2": + # Rotary position embeddings + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + self.rope = RotaryPositionEmbeddingPytorchV2( + seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config + ) + self._broadcast_pos_emb(self.rope.cos_cached, tp_group=self._get_tp_group()) + self._broadcast_pos_emb(self.rope.sin_cached, tp_group=self._get_tp_group()) + else: + raise ValueError(f"Invalid PyTorch RoPE version: {self.params['pytorch_rope_version']}") + # Causal mask + self.causal_mask = torch.tril( + torch.ones(self.params["max_seq_len"], self.params["max_seq_len"], dtype=torch.bool) + ).cuda() + + # Output projection + self.output = self._create_output_projection() + + # Freeze network parameters for finetuning w/ cross-attention + self.has_cross_attention = getattr(params, "insert_cross_attn", False) + + # Absolute position embeddings + if self.params["apply_abs_pos_emb"]: + self.pos_emb_config = self._create_abs_pos_emb_config() + self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb() + self._broadcast_pos_emb(self.abs_pos_emb, tp_group) + + def _create_rope_config(self) -> Dict: + shape_map = { + "3D": self.params["video_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + head_dim = self.params["head_dim"] + if head_dim is None: + head_dim = self.params["dim"] // self.params["n_heads"] + return { + "dim": head_dim, + 
"max_position_embeddings": self.params["max_seq_len"], + "original_max_position_embeddings": self.params["original_seq_len"], + "rope_theta": self.params["rope_theta"], + "apply_yarn": self.params["apply_yarn"], + "scale": self.params["yarn_scale"], + "beta_fast": self.params["yarn_beta_fast"], + "beta_slow": self.params["yarn_beta_slow"], + "rope_dim": self.params["rope_dim"], + "latent_shape": latent_shape, + "original_latent_shape": self.params["original_latent_shape"], + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_abs_pos_emb_config(self): + shape_map = { + "3D": self.params["video_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + return { + "dim": self.params["dim"], + "latent_shape": latent_shape, + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_token_embeddings(self, model_parallel=None, vocab_size: int = None): + """ + Create token embeddings. + + Args: + model_parallel: The model parallel configuration. + + Returns: + nn.Module: Token embeddings module. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + tp_size = self.params["tensor_model_parallel_size"] + if tp_size > 1: + emb = VocabParallelEmbedding( + vocab_size, + self.params["dim"], + ).to(self.precision) + return emb + else: + return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision) + + def _get_tp_group( + self, + ): + """ + Get tensor parallel process group if applicable. + + Returns: + torch.distributed.ProcessGroup or None: Tensor parallel process group if tensor parallelism is enabled, else None. + """ + if self.params["tensor_model_parallel_size"] > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + log.debug(f"Using tensor model parallel group: {tp_group}") + return tp_group + + return None + + def _create_output_projection(self, vocab_size: int = None): + """ + Create the output projection layer. + + Args: + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + vocab_size (int): Vocabulary size (to override the default vocab size). + Returns: + LinearTE: Output projection layer. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + tp_size = self.params["tensor_model_parallel_size"] + return nn.Linear(self.params["dim"], vocab_size // tp_size, bias=False).to(self.precision) + + def _initialize_abs_pos_emb(self): + pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config) + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + abs_pos_emb = pos_emb.forward(training_type=training_type) + return pos_emb, abs_pos_emb + + def _broadcast_pos_emb(self, pos_emb, tp_group): + """ + Broadcast the position embeddings across the tensor parallel group. + + Args: + pos_emb (torch.Tensor): Position embeddings to broadcast. + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + """ + if self.params["tensor_model_parallel_size"] > 1: + broadcast(pos_emb, min(get_process_group_ranks(tp_group)), group=tp_group) + + def forward( + self, + tokens: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + token_embeddings: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the Transformer module. + + Args: + tokens (torch.Tensor, optional): The input tensor of token IDs. 
+ input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache. + token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + Returns: + The output tensor after applying the transformer layers. + """ + # Token embeddings + assert ( + tokens is None or token_embeddings is None + ), "Either tokens or token_embeddings should be provided, not both." + + if token_embeddings is None: + seq_len = tokens.shape[1] + h = self.tok_embeddings(tokens) + else: + seq_len = token_embeddings.shape[1] + h = token_embeddings + + # Create attention mask + mask = self._create_attention_mask(input_pos=input_pos) + + # Prepare layer arguments + layer_kwargs = self._prepare_layer_kwargs( + input_pos=input_pos, + mask=mask, + context=context, + context_mask=context_mask, + ) + + # Apply transformer layers + for layer in self.layers: + if self.params["apply_abs_pos_emb"]: + h = self.apply_abs_pos_emb(h, input_pos=input_pos) + h = layer(h, **layer_kwargs) + + # Apply final layer normalization + h = self.norm(h) + + # Output linear projection + output = self.output(h) + if self.params["tensor_model_parallel_size"] > 1: + # Use PyTorch all gather + output = all_gather_tensor(output, gather_dim=-1, group=parallel_state.get_tensor_model_parallel_group()) + return output + + def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """ + Creates an attention mask for the transformer layers. + + Args: + input_pos[torch.Tensor]: The position of input sequence (used for inference only). + + Returns: + Optional[torch.Tensor]: The attention mask, or None for causal mask. + """ + + assert input_pos is not None, "input_pos must be provided for inference" + mask = self.causal_mask[input_pos] + return mask + + def _prepare_layer_kwargs( + self, + input_pos: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + context: Optional[torch.Tensor], + context_mask: Optional[torch.Tensor], + ) -> Dict[str, Any]: + """ + Prepares the keyword arguments for transformer layers. + + Args: + input_pos (Optional[torch.Tensor]): The position of the current sequence. + mask (Optional[torch.Tensor]): The attention mask. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + Dict[str, Any]: A dictionary of keyword arguments for the transformer layers. + """ + if context is not None: + context = context.to(self.precision) + + if isinstance(mask, torch.Tensor) and mask.ndim == 2: + mask = mask[None, None, :, :] + if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2: + context_mask = context_mask[None, None, :, :] + + layer_kwargs = { + "mask": mask, + "context": context, + "context_mask": context_mask, + } + + layer_kwargs["input_pos"] = input_pos + layer_kwargs["rope"] = self.rope + + return layer_kwargs + + def apply_abs_pos_emb(self, x: torch.Tensor, input_pos: int = None) -> torch.Tensor: + """ + Applies the absolute position embeddings to the input tensor. 
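Two details of the Transformer forward above: each rank's output head produces only vocab_size // tp_size logits, which are then all-gathered along the last dimension, and during KV-cache inference the attention mask is simply the rows of the precomputed lower-triangular matrix selected by input_pos. A single-process sketch of both, with hypothetical sizes:

```python
import torch

# Logits sharding: each of tp_size ranks owns a contiguous slice of the vocabulary.
dim, vocab, tp_size = 8, 12, 2
h = torch.randn(2, 5, dim)                            # [batch, seq, dim]
full_head = torch.randn(vocab, dim)                   # unsharded output projection
per_rank = [h @ full_head[r * vocab // tp_size:(r + 1) * vocab // tp_size].T
            for r in range(tp_size)]
logits = torch.cat(per_rank, dim=-1)                  # what the all-gather produces
assert torch.allclose(logits, h @ full_head.T, atol=1e-5)

# Causal mask lookup for KV-cache decoding: select rows by absolute position.
max_seq_len = 8
causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.tensor([5])                         # decoding the 6th token
print(causal[input_pos])                              # True for positions 0..5, False after
```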
+ """ + abs_pos_emb = self.abs_pos_emb + abs_pos_emb = abs_pos_emb[:, input_pos, :] if input_pos is not None else abs_pos_emb + return x + abs_pos_emb + + @torch.no_grad() + def expand_vocab( + self, new_vocab_size: int, init_method: str = "gaussian", multiple_of=64, expand_output_layer=True + ): + """ + Expands the vocabulary of the model to the new size. + + Args: + new_vocab_size (int): The new vocabulary size. + init_method (str): The initialization method for new embeddings. + Can be "zero" or "gaussian". Default is "gaussian". + multiple_of (int): The new vocabulary size must be a multiple of this value. Defaults to 64 to fully + leverage the power of NVIDIA TensorCore (source 1: https://x.com/karpathy/status/1621578354024677377, + source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc) + expand_output_layer (bool): Whether to also expand the output layer. Defaults to True. + + Returns: + None + """ + tp_size = self.params["tensor_model_parallel_size"] + if new_vocab_size <= self.vocab_size: + raise ValueError( + f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})" + ) + if new_vocab_size % multiple_of != 0: + log.debug(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.") + new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of + log.debug(f"Rounded vocabulary size to {new_vocab_size}.") + # Resize token embeddings + old_embeddings = self.tok_embeddings + tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype} + self.tok_embeddings = self._create_token_embeddings( + model_parallel=self.model_parallel, vocab_size=new_vocab_size + ).to(**tensor_kwargs) + # Initialize new embeddings + if init_method not in ["zero", "gaussian"]: + raise ValueError(f"Unknown initialization method: {init_method}") + # The default initialization of nn.Embedding is Gaussian, so we don't need to do anything + # if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings. + if init_method == "zero": + self.tok_embeddings.weight.data[self.vocab_size // tp_size :].zero_() + + # Copy old embeddings + log.debug( + f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}" + ) + self.tok_embeddings.weight.data[: self.vocab_size // tp_size] = old_embeddings.weight.data + # Resize output layer + old_output = self.output + self.output = self._create_output_projection(vocab_size=new_vocab_size if expand_output_layer else None) + + # Initialize new output weights + self.output.weight.data[self.vocab_size // tp_size :].zero_() + # Copy old output weights + self.output.weight.data[: self.vocab_size // tp_size] = old_output.weight.data + + # Update vocab size + self.vocab_size = new_vocab_size + log.debug(f"Expanded vocabulary size to {new_vocab_size}") + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). 
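expand_vocab above rounds the requested size up to a multiple of 64 (for TensorCore-friendly shapes), rebuilds the embedding and output layers at the new size, copies the old rows, and zero-initializes (or leaves Gaussian-initialized) the new ones. A minimal sketch of the rounding-and-copy step on a plain nn.Embedding, ignoring tensor parallelism:

```python
import torch
import torch.nn as nn

old_vocab, dim, multiple_of = 64000, 256, 64
new_vocab = 64010                                     # requested size, not a multiple of 64
if new_vocab % multiple_of != 0:
    new_vocab = (new_vocab // multiple_of + 1) * multiple_of   # -> 64064

old_emb = nn.Embedding(old_vocab, dim)
new_emb = nn.Embedding(new_vocab, dim)
with torch.no_grad():
    new_emb.weight[old_vocab:].zero_()                # "zero" init for the new rows
    new_emb.weight[:old_vocab].copy_(old_emb.weight)  # keep the old rows
```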
+ """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + if strict: + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + missing_keys = actual_missing_keys + return _IncompatibleKeys(missing_keys, unexpected_keys) diff --git a/cosmos_predict1/autoregressive/networks/vit.py b/cosmos_predict1/autoregressive/networks/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..46405ac37fe90c3e212e37493d68774702f02882 --- /dev/null +++ b/cosmos_predict1/autoregressive/networks/vit.py @@ -0,0 +1,412 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module implements a Vision Transformer (ViT) with 2D Rotary Position Embeddings, +designed for processing image inputs in vision-language models. + +This module follows Mistral's vision encoder implementation (for their Pistral-12B VLM): +https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py +""" +from functools import partial +from typing import Any, Callable, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.networks.transformer import TransformerBlock +from cosmos_predict1.utils import log + + +def get_vit_config(model_name: str) -> Mapping[str, Any]: + """ + Get the ViT configuration for a given model name. + """ + if model_name == "pixtral-12b-vit": + # The 400M ViT of Pixtral 12B VLM + return dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + else: + raise ValueError(f"Unknown model name: {model_name}") + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + Precompute 2D complex tensor for rotary position embedding. + + This function generates a 2D complex tensor used for rotary position embeddings, + which helps the model understand spatial relationships in the input image. + + Args: + dim (int): Dimension of the model (typically the hidden size divided by number of heads). + height (int): Height of the image in patches. + width (int): Width of the image in patches. + theta (float): Base value for the angle calculation, controls the frequency range. + + Returns: + torch.Tensor: 2D complex tensor of shape (height, width, dim // 2). 
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """ + Reshape frequency tensor for broadcasting with input tensor. + + This function ensures that the frequency tensor can be properly broadcast + with the input tensor during the rotary embedding process. + + Args: + freqs_cis (torch.Tensor): Frequency tensor from precompute_freqs_cis_2d. + x (torch.Tensor): Input tensor to be embedded. + + Returns: + torch.Tensor: Reshaped frequency tensor ready for broadcasting. + """ + ndim = x.ndim + assert 0 <= 1 < ndim, f"ndim is {ndim} but index is {1}" + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape is {freqs_cis.shape} but x shape is {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + *args, + freqs_cis: torch.Tensor, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary positional embeddings to input tensors. + + This function applies the rotary positional embeddings to the query and key tensors, + which helps the model understand spatial relationships in the input. + + Args: + xq (torch.Tensor): Query tensor. + xk (torch.Tensor): Key tensor. + freqs_cis (torch.Tensor): Precomputed frequencies from precompute_freqs_cis_2d. + *args: Variable length argument list (unused). + **kwargs: Arbitrary keyword arguments (unused). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class VisionTransformer(nn.Module): + """ + Vision Transformer model for image processing. + + This class implements a Vision Transformer that processes images using a patch-based approach + and applies transformer layers with rotary position embeddings. + + Args: + dim (int): Dimension of the model (hidden size). + num_channels (int): Number of input image channels (e.g., 3 for RGB). + patch_size (int): Size of each image patch (e.g., 16x16 pixels). + n_layers (int): Number of transformer layers. + n_heads (int): Number of attention heads. + ffn_hidden_size (int): Hidden size of the feed-forward network in transformer blocks. + norm_type (str): Type of normalization to use (e.g., "rmsnorm"). + norm_eps (float): Epsilon value for normalization layers. + image_size (int): Size of the input image (assumed square). + rope_theta (float): Base value for rotary position embedding calculation. + attention_dropout (float): Dropout rate for attention layers. + hidden_dropout (float): Dropout rate for hidden layers. + image_token_id (int): Token ID for the image token (if present). 
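precompute_freqs_cis_2d above devotes half of the head dimension's frequency pairs to the patch row and the other half to the patch column, returning one complex rotation per (row, column, frequency pair); at run time each patch's (row, col) index selects its rotation. A quick shape check with illustrative values, assuming the function above is in scope:

```python
import torch

head_dim, height, width, theta = 64, 4, 4, 10000.0    # illustrative values
freqs_cis = precompute_freqs_cis_2d(head_dim, height, width, theta)
print(freqs_cis.shape, freqs_cis.dtype)               # torch.Size([4, 4, 32]) torch.complex64
```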
+ """ + + def __init__( + self, + dim: int = 1024, + num_channels: int = 3, + patch_size: int = 16, + n_layers: int = 24, + n_heads: int = 16, + n_kv_heads: int = None, + ffn_hidden_size: int = 4096, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + image_size: int = 1024, + rope_theta: float = 1000000.0, + image_token_id: int = None, + tensor_model_parallel_size: int = 1, + ): + super().__init__() + self.patch_conv = nn.Conv2d( + in_channels=num_channels, + out_channels=dim, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + self.ln_pre = create_norm(norm_type=norm_type, dim=dim, eps=norm_eps) + if n_kv_heads is None: + n_kv_heads = n_heads + layer_args = dict( + n_layers=n_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dim=dim, + use_qk_normalization=False, + max_seq_len=None, + max_batch_size=None, + ffn_hidden_size=ffn_hidden_size, + norm_type=norm_type, + norm_eps=norm_eps, + causal_mask=False, # Full attention in ViT + head_dim=None, + insert_cross_attn=False, + tensor_model_parallel_size=tensor_model_parallel_size, + attn_type="full", + ) + + self.transformer = VisionTransformerBlocks(n_layers=n_layers, args=layer_args) + + head_dim = dim // n_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + + self.dim = dim + self.n_heads = n_heads + self.max_patches_per_side = image_size // patch_size + self.image_size = image_size + self.patch_size = patch_size + self.rope_theta = rope_theta + self._freqs_cis: Optional[torch.Tensor] = None + self.image_token_id = image_token_id + + num_params = self.get_num_params() + log.debug(f"Number of model parameters: {round(num_params / 1e6, 3)}M") + + @classmethod + def build( + cls, + config: Mapping[str, Any], + ) -> "VisionTransformer": + """ + Create a Vision Transformer from a configuration dictionary. + + This class method creates a Vision Transformer from a configuration dictionary, + which is typically loaded from a JSON file or other configuration source. + + Args: + config (Mapping[str, Any]): Configuration dictionary for the Vision Transformer. + + Returns: + VisionTransformer: Vision Transformer model instance. + """ + necessary_keys = ["dim", "num_channels", "patch_size", "n_layers", "n_heads", "ffn_hidden_size", "rope_theta"] + missing_keys = [k for k in necessary_keys if k not in config] + assert len(missing_keys) == 0, f"Missing keys in config: {missing_keys}" + return cls( + **config, + ) + + def expand_in_channels(self, new_in_channels: int): + """ + Expand the input channels of the patch convolution layer. + This is useful when the input is non-standard, e.g. a 4-channel image with the last channel as the alpha channel. + Note that you should only call this method after the weight is loaded. + """ + assert ( + new_in_channels > self.patch_conv.in_channels + ), "Cannot expand the input channels of the patch convolution layer to be less than the original number of channels." + log.debug( + f"Vision encoder in_channels is {self.patch_conv.in_channels}. But you have specified to be {new_in_channels}. We will change it to {new_in_channels} channels with {new_in_channels - self.patch_conv.in_channels} channels of 0s." 
+ ) + new_conv = nn.Conv2d( + in_channels=new_in_channels, + out_channels=self.patch_conv.out_channels, + kernel_size=self.patch_conv.kernel_size, + stride=self.patch_conv.stride, + bias=False, + ) + new_conv.weight.data[:, : self.patch_conv.in_channels].copy_(self.patch_conv.weight.data) + new_conv.weight.data[ + :, self.patch_conv.in_channels : + ].zero_() # zeroize, such that initially it has no effect to output + self.patch_conv = new_conv + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + """ + Get or compute the frequency tensor for rotary position embedding. + + This property lazily initializes and caches the frequency tensor used for + rotary position embeddings, ensuring it's on the correct device. + + Returns: + torch.Tensor: The frequency tensor for rotary position embeddings. + """ + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.dim // self.n_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the Vision Transformer. + + This method processes the input image through the Vision Transformer, + including patch embedding, position embedding, and transformer layers. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is batch size, + C is number of channels, and H, W are height and width. + + Returns: + torch.Tensor: Output features of shape (B, N, D), where N is the number of patches + and D is the embedding dimension. + """ + + patch_embeds = self.patch_conv(x) # (B, D, Hp, Wp) + _, _, Hp, Wp = patch_embeds.shape # Patch embeds dim + patch_embeds = patch_embeds.flatten(2) # (B, D, Hp*Wp) + patch_embeds = patch_embeds.transpose(1, 2) # (B, Hp*Wp, D) + patch_embeds = self.ln_pre(patch_embeds) # (B, Hp*Wp, D) + positions = torch.stack( + torch.meshgrid( + torch.arange(Hp), + torch.arange(Wp), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + rope = partial(apply_rotary_emb, freqs_cis=freqs_cis) + out = self.transformer(patch_embeds, rope=rope) + + return out + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + +class VisionTransformerBlocks(nn.Module): + """ + Vision Transformer Blocks. + + This class implements a stack of Transformer blocks used in the Vision Transformer. + + Args: + n_layers (int): Number of transformer layers. + args (Mapping[str, Any]): Arguments for each transformer block, including dimensions, + """ + + def __init__( + self, + n_layers: int, + args: Mapping[str, Any], + ): + super().__init__() + self.layers = torch.nn.ModuleList() + + for layer_id in range(n_layers): + self.layers.append( + TransformerBlock( + layer_id=layer_id, + args=args, + ) + ) + + def forward( + self, + x: torch.Tensor, + rope: Callable, + ) -> torch.Tensor: + """ + Forward pass through the Vision Transformer Blocks. + + This method applies a series of Transformer blocks to the input tensor, + using the provided rotary position embedding function. 
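With the pixtral-12b-vit configuration (image_size=1024, patch_size=16), the conv patchifier yields at most a 64x64 token grid, and the meshgrid positions built in forward pick each token's 2D RoPE rotation by its (row, column) index. A small sanity check of that arithmetic for a non-square input:

```python
import torch

# Illustrative pixtral-12b-vit numbers: 1024 / 16 = 64 patches per side at most.
image_size, patch_size = 1024, 16
print(image_size // patch_size)                       # 64 -> up to 64 * 64 = 4096 tokens

# A 512x256 input gives a 32x16 patch grid; positions index the 2D RoPE table.
Hp, Wp = 512 // patch_size, 256 // patch_size
positions = torch.stack(
    torch.meshgrid(torch.arange(Hp), torch.arange(Wp), indexing="ij"), dim=-1
).reshape(-1, 2)
print(Hp * Wp, positions.shape)                       # 512 torch.Size([512, 2])
```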
+ + Args: + x (torch.Tensor): Input tensor of shape (B, N, D), where B is batch size, + N is the number of patches, and D is the embedding dimension. + rope (Callable): Rotary position embedding function to be applied in each layer. + + Returns: + torch.Tensor: Output tensor after passing through all transformer layers, + with the same shape as the input. + """ + for layer in self.layers: + x = layer(x, input_pos=None, mask=None, rope=rope) + return x diff --git a/cosmos_predict1/autoregressive/tokenizer/__init__.py b/cosmos_predict1/autoregressive/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/autoregressive/tokenizer/discrete_video.py b/cosmos_predict1/autoregressive/tokenizer/discrete_video.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd0a832b264387c28932c7dbc4dcf77fbbd935b --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/discrete_video.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from einops import rearrange + +from cosmos_predict1.autoregressive.tokenizer.quantizers import FSQuantizer + +# Make sure jit model output consistenly during consecutive calls +# Check here: https://github.com/pytorch/pytorch/issues/74534 +torch._C._jit_set_texpr_fuser_enabled(False) + + +def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: + """Loads a torch.jit.ScriptModule from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + # Make sure jit model output consistenly during consecutive calls + # Check here: https://github.com/pytorch/pytorch/issues/74534 + torch._C._jit_set_texpr_fuser_enabled(False) + + model = torch.jit.load(jit_filepath) + return model.eval().to(device) + + +class BaseDiscreteVideoFSQTokenizer(torch.nn.Module): + """ + A base class for Discrete Video FSQ Tokenizer that handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components into a encoder and decoder attributes. + + Attributes: + encoder (Module | Callable): Encoder loaded from storage. + decoder (Module | Callable): Decoder loaded from storage. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
+ pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. + max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__() + self.channel = latent_ch + self.name = name + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.pixel_chunk_duration = pixel_chunk_duration + self.latent_chunk_duration = latent_chunk_duration + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + self.levels = levels + self.compress_ratio = compression_ratio + self.fsq_quantizer = FSQuantizer(levels) + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the tokenizer. + """ + return self.channel + + @torch.no_grad() + def encode(self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor: + B, C, T, H, W = state.shape + if pixel_chunk_duration is None: + # Use the default pixel chunk duration and latent chunk duration + pixel_chunk_duration = self.pixel_chunk_duration + latent_chunk_duration = self.latent_chunk_duration + else: + # Update the latent chunk duration based on the given pixel chunk duration + latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0] + + assert ( + T % pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}" + state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration) + + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + quantized_out_list = [] + indices_list = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype)) + quantized_out_list.append(quantized_out) + indices_list.append(indices) + quantized_out = torch.cat(quantized_out_list, dim=0) + indices = torch.cat(indices_list, dim=0) + else: + indices, quantized_out, _ = self.encoder(state.to(self.dtype)) + assert quantized_out.shape[2] == latent_chunk_duration + return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange( + indices, "(b n) t h w -> b (n t) h w", b=B + ) + + @torch.no_grad() + def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor: + B, T, _, _ = indices.shape + if pixel_chunk_duration is None: + pixel_chunk_duration = self.pixel_chunk_duration + latent_chunk_duration = self.latent_chunk_duration + else: + latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0] + assert ( + T % latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}" + indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration) + + # use max_dec_batch_size to avoid OOM + if indices.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, 
indices.shape[0], self.max_dec_batch_size): + state.append(self.decoder(indices[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = self.decoder(indices) + + assert state.shape[2] == pixel_chunk_duration + return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer): + """ + A JIT compiled Discrete Video FSQ Tokenizer that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 6). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level. + max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow. + level (list[int]): The level defined in FSQ quantizer. + compression_ratio (list[int]): The compression factor for (T, H, W). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + latent_ch: int = 6, + is_bf16: bool = True, + pixel_chunk_duration: int = 25, + latent_chunk_duration: int = 4, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + levels: list[int] = [8, 8, 8, 5, 5, 5], + compression_ratio: list[int] = [8, 16, 16], + ): + super().__init__( + name, + latent_ch, + is_bf16, + pixel_chunk_duration, + latent_chunk_duration, + max_enc_batch_size, + max_dec_batch_size, + levels, + compression_ratio, + ) + + self.load_encoder(enc_fp) + self.load_decoder(dec_fp) + + def load_encoder(self, enc_fp: str) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + """ + self.encoder = load_jit_model(enc_fp, device="cuda") + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, dec_fp: str) -> None: + """ + Load the decoder from the remote store. + + Args: + - dec_fp (str): File path to the decoder's JIT file on the remote store. 
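The temporal chunking used by encode/decode above follows latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // temporal_compression, and the temporal dimension must divide evenly into chunks. A quick check with the default settings (pixel_chunk_duration=25, compression_ratio=[8, 16, 16]):

```python
pixel_chunk_duration = 25
compression_ratio = [8, 16, 16]          # (T, H, W) compression factors

latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // compression_ratio[0]
print(latent_chunk_duration)             # 4, matching the default

# A 50-frame clip splits into 50 / 25 = 2 pixel chunks, i.e. 2 * 4 = 8 latent frames.
T = 50
assert T % pixel_chunk_duration == 0
print((T // pixel_chunk_duration) * latent_chunk_duration)   # 8
```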
+        """
+        self.decoder = load_jit_model(dec_fp, device="cuda")
+        self.decoder.eval()
+        for param in self.decoder.parameters():
+            param.requires_grad = False
+        self.decoder.to(self.dtype)
+
+
+class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer):
+    """
+    A Discrete Video FSQ Tokenizer that loads the pre-trained JIT encoder's weights into an nn.Module
+    (so that the encoder can be wrapped with torch.compile()) and loads the JIT-compiled decoder as is.
+    It handles data type conversions and normalization using provided mean and standard deviation values
+    for the latent space representation.
+
+    Attributes:
+        tokenizer_module (Module): Tokenizer module with weights loaded from JIT checkpoints
+        encoder (Callable): tokenizer_module's encode method
+        decoder (Callable): tokenizer_module's decode method
+        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
+
+    Args:
+        enc_fp (str): File path to the encoder's JIT file on the remote store.
+        dec_fp (str): File path to the decoder's JIT file on the remote store.
+        tokenizer_module (Module): Tokenizer module that will have its weights loaded
+        name (str): Name of the model, used for differentiating cache file paths.
+        latent_ch (int, optional): Number of latent channels (default is 6).
+        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
+        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
+        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
+        max_enc_batch_size (int): The maximum encode batch size to process in one go to avoid memory overflow.
+        max_dec_batch_size (int): The maximum decode batch size to process in one go to avoid memory overflow.
+        levels (list[int]): The levels defined in the FSQ quantizer.
+        compression_ratio (list[int]): The compression factor for (T, H, W).
+    """
+
+    def __init__(
+        self,
+        enc_fp: str,
+        dec_fp: str,
+        tokenizer_module: torch.nn.Module,
+        name: str,
+        latent_ch: int = 6,
+        is_bf16: bool = True,
+        pixel_chunk_duration: int = 25,
+        latent_chunk_duration: int = 4,
+        max_enc_batch_size: int = 8,
+        max_dec_batch_size: int = 4,
+        levels: list[int] = [8, 8, 8, 5, 5, 5],
+        compression_ratio: list[int] = [8, 16, 16],
+    ):
+        super().__init__(
+            name,
+            latent_ch,
+            is_bf16,
+            pixel_chunk_duration,
+            latent_chunk_duration,
+            max_enc_batch_size,
+            max_dec_batch_size,
+            levels,
+            compression_ratio,
+        )
+
+        self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module)
+
+    def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None:
+        """
+        Load the encoder and decoder from the remote store.
+
+        Args:
+        - enc_fp (str): File path to the encoder's JIT file on the remote store.
+        - dec_fp (str): File path to the decoder's JIT file on the remote store.
+ - tokenizer_module (Module): Tokenizer module that was used to create JIT checkpoints + """ + self.decoder = load_jit_model(dec_fp) + + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + encoder_sd = load_jit_model(enc_fp).state_dict() + + del tokenizer_module.post_quant_conv + del tokenizer_module.decoder + + state_dict = { + k: v + for k, v in (encoder_sd).items() + # Variables captured by JIT + if k + not in ( + "encoder.patcher3d.wavelets", + "encoder.patcher3d._arange", + "encoder.patcher3d.patch_size_buffer", + "quantizer._levels", + "quantizer._basis", + "quantizer.implicit_codebook", + ) + } + + tokenizer_module.load_state_dict(state_dict) + + tokenizer_module.eval() + for param in tokenizer_module.parameters(): + param.requires_grad = False + tokenizer_module.to(self.dtype) + + self.tokenizer_module = tokenizer_module + self.encoder = self.tokenizer_module.encode + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.tokenizer_module.to(self.dtype) diff --git a/cosmos_predict1/autoregressive/tokenizer/image_text_tokenizer.py b/cosmos_predict1/autoregressive/tokenizer/image_text_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..447092664fb89f7e205721571646b3ea29fe7d71 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/image_text_tokenizer.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import transformers +from transformers import AutoImageProcessor +from transformers.image_utils import ImageInput, is_valid_image, load_image + +from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer +from cosmos_predict1.utils import log + +# Configuration for different vision-language models +IMAGE_CONFIGS = { + "pixtral": { + "patch_size": 16, + "image_token": "[IMG]", + "image_break_token": "[IMG_BREAK]", + "image_end_token": "[IMG_END]", + } +} + +# Chat template for Pixtral-12B-Instruct +PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}' + + +# Copied from transformers.models.pixtral.processing_pixtral.is_url +def is_url(val) -> bool: + """Check if the given value is a URL.""" + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url +def is_image_or_image_url(elem): + """Check if the given element is an image or an image URL.""" + return is_url(elem) or is_valid_image(elem) + + +def load_image_list( + image_list: List[Union[str, "PIL.Image.Image"]], timeout: Optional[float] = None +) -> List["PIL.Image.Image"]: + """ + Load a list of images. + + Args: + image_list (List[Union[str, PIL.Image.Image]]): The list of images to load. + timeout (Optional[float]): The timeout for loading the image. + + Returns: + List[PIL.Image.Image]: The list of loaded images. + """ + return [load_image(image, timeout=timeout) for image in image_list] + + +class ImageTextTokenizer(TextTokenizer): + """ + Image-text tokenizer class that extends the text tokenizer to support vision tokens as well. + """ + + def __init__( + self, + model_family: str, + is_instruct_model: bool, + tokenizer_path: str, + image_processor_path: str, + ): + """ + Initialize the ImageTextTokenizer. + + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + + Raises: + AssertionError: If the model family is not supported or if the transformers version is incompatible. 
+ """ + super().__init__( + model_family=model_family, + is_instruct_model=is_instruct_model, + local_path=tokenizer_path, + ) + assert model_family in ["pixtral"], f"Unsupported model family: {model_family}" + if model_family == "pixtral": + # Need transformers>=4.45.0 + assert transformers.__version__ >= "4.45.0", "Pixtral requires transformers>=4.45.0" + assert is_instruct_model, "Pixtral requires is_instruct_model=True" + if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None: + setattr(self.tokenizer, "chat_template", PIXTRAL_CHAT_TEMPLATE) + log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}") + + # Set up image-specific configurations + image_config = IMAGE_CONFIGS[model_family] + self.patch_size = image_config["patch_size"] + self.image_token = image_config["image_token"] + self.image_break_token = image_config["image_break_token"] + self.image_end_token = image_config["image_end_token"] + + # Initialize the image processor + self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path) + + def encode( + self, + text: Union[str, List[str], List[int]], + *, # Enforce keyword-only arguments + images: Optional[ImageInput] = None, + image_kwargs: Optional[Dict[str, Any]] = None, + **text_kwargs, + ) -> List[int]: + """ + Process the images and return the tokenized images and text. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. + image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing. + **text_kwargs: Additional keyword arguments for text processing. + + Returns: + A dictionary with the following fields: + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **pixel_values** -- Pixel values to be fed to a model. + + Raises: + ValueError: If the input images are in an invalid format. + """ + + output_dict, image_inputs = {}, {} + + if images is not None: + if is_image_or_image_url(images): + images = [images] + elif isinstance(images, (list, tuple)) and is_image_or_image_url(images[0]): + pass + elif ( + isinstance(images, (list, tuple)) + and isinstance(images[0], (list, tuple)) + and is_image_or_image_url(images[0][0]) + ): + images = [image for sublist in images for image in sublist] + else: + raise ValueError( + "Invalid input images. Please provide a single image, a list of images, or a list of lists of images." 
+ ) + images = [load_image(im) if isinstance(im, str) else im for im in images] + image_kwargs = image_kwargs or {} + image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs) + + # Validate image inputs + assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs" + assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs" + assert len(image_inputs.keys()) == 2, "Expected exactly pixel_values and image_sizes in image_inputs, got {}".format( + image_inputs.keys() + ) + + # Extract pixel values and image sizes + pixel_values = image_inputs["pixel_values"] + image_sizes = image_inputs["image_sizes"] + unique_sizes = np.unique(image_sizes, axis=0) + + assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes) + + # Convert pixel values to PyTorch tensor + pixel_values = np.asarray(pixel_values) + pixel_values = torch.from_numpy(pixel_values) + output_dict["pixel_values"] = pixel_values + output_dict["image_sizes"] = image_sizes + + # Expand image tokens in text + if image_inputs.get("pixel_values") is not None: + replace_strings = [] + # Calculate the number of tokens needed for each image and create a placeholder + for image_size in image_sizes: + height, width = image_size + num_height_tokens = height // self.patch_size + num_width_tokens = width // self.patch_size + replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens + # Flatten list + replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens[-1] = self.image_end_token + replace_str = "".join(replace_tokens) + replace_strings.append(replace_str) + text = text.replace(self.image_token, "<placeholder>", 1) + + # Replace placeholders with actual image token sequences + while "<placeholder>" in text: + replace_str = replace_strings.pop(0) + text = text.replace("<placeholder>", replace_str, 1) + + # Encode the text + text_inputs = super(ImageTextTokenizer, self).encode(text, **text_kwargs) + + output_dict["input_ids"] = text_inputs + return output_dict + + def apply_chat_template( + self, + conversation: List[Dict[str, Any]] | List[List[Dict[str, Any]]], + *, + images: Optional[ImageInput] = None, + image_kwargs: Optional[Dict[str, Any]] = None, + add_generation_prompt: bool = False, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[str] = None, + return_dict: bool = True, + return_assistant_tokens_mask: bool = False, + generation_prefix: str = "", + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Apply the chat template to the conversation. + + Args: + conversation (List[Dict[str, Any]] | List[List[Dict[str, Any]]]): The conversation to process. + images (Optional[ImageInput]): Images to include in the conversation. + image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing. + add_generation_prompt (bool): Whether to add a generation prompt. + tokenize (bool): Whether to tokenize the output. + padding (bool): Whether to pad the output. + truncation (bool): Whether to truncate the output. + max_length (Optional[int]): Maximum length of the output. + return_tensors (Optional[str]): The type of tensors to return. + return_dict (bool): Whether to return a dictionary. + return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask. + generation_prefix (str): Prefix to add before asking model to generate.
Helpful to guide the generation. Defaults to "". + tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer. + **kwargs: Additional keyword arguments. + + Returns: + The processed conversation with applied chat template. + + Raises: + AssertionError: If return_dict is False or if the conversation format is invalid. + """ + assert return_dict, "return_dict must be True for ImageTextTokenizer" + assert isinstance(conversation, list), "conversation must be a list" + if isinstance(conversation[0], list): + assert len(conversation) == 1, "Only support single-conversation input, got {}".format(conversation) + conversation = conversation[0] + + # Extract images from the conversation if not provided + if images is None: + images = [] + for msg in conversation: + if msg.get("images", None) is not None: + images = images + (msg["images"]) + images = load_image_list(images) + # In case the input does not have images, will ignore + # Useful in feeding VLM inputs with and without images + if isinstance(images, list) and len(images) == 0: + images = None + + # Apply the chat template to the text + text = super().apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=add_generation_prompt, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_dict=False, + return_assistant_tokens_mask=return_assistant_tokens_mask, + generation_prefix=generation_prefix, + tokenizer_kwargs=tokenizer_kwargs, + **kwargs, + ) + + if tokenizer_kwargs is None: + tokenizer_kwargs = {} + + # Encode the text and images + output = self.encode( + text, + images=images, + image_kwargs=image_kwargs, + tokenize=tokenize, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + return output + + @property + def model_input_names(self): + """ + Get the combined model input names from both the text tokenizer and image processor. + + Returns: + List[str]: A list of unique input names. + """ + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/cosmos_predict1/autoregressive/tokenizer/modules.py b/cosmos_predict1/autoregressive/tokenizer/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..68fee5493ec40f3ba4ba205eb2e3c26dd1c5c9f0 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/modules.py @@ -0,0 +1,560 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
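As a quick sanity check on the `[IMG]` expansion performed in `ImageTextTokenizer.encode` above, the number of image tokens follows directly from the processed image size and the 16-pixel Pixtral patches; the 256x512 size below is only an assumed example of what the image processor might return:

```python
# Standalone illustration of the image-token expansion in ImageTextTokenizer.encode.
height, width, patch_size = 256, 512, 16                  # assumed processed image size
IMG, BREAK, END = "[IMG]", "[IMG_BREAK]", "[IMG_END]"

num_h, num_w = height // patch_size, width // patch_size  # 16 rows x 32 columns of patches
rows = [[IMG] * num_w + [BREAK] for _ in range(num_h)]    # one [IMG_BREAK] per row
tokens = [tok for row in rows for tok in row]
tokens[-1] = END                                          # the final break becomes [IMG_END]
placeholder = "".join(tokens)

assert len(tokens) == num_h * (num_w + 1)                 # 528 image tokens for this size
```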
+ +"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/ +magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from cosmos_predict1.autoregressive.tokenizer.patching import Patcher3D, UnPatcher3D +from cosmos_predict1.autoregressive.tokenizer.utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) +from cosmos_predict1.utils import log + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalHybridUpsample3d(nn.Module): + def __init__(self, in_channels: int, spatial_up: bool = True, temporal_up: bool = True, **ignore_kwargs) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0) + if temporal_up + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1) + if spatial_up + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_up or temporal_up + else nn.Identity() + ) + self.spatial_up = spatial_up + self.temporal_up = temporal_up + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. 
+ x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, in_channels: int, spatial_down: bool = True, temporal_down: bool = True, **ignore_kwargs + ) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0) + if spatial_down + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0) + if temporal_down + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_down or temporal_down + else nn.Identity() + ) + self.spatial_down = spatial_down + self.temporal_down = temporal_down + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__(self, *, in_channels: int, out_channels: int = None, dropout: float, num_groups: int) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d(out_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) 
+ k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + temporal_compression: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d(in_channels, channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1 + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, spatial_down=spatial_down, temporal_down=temporal_down + ) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(z_channels, z_channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: 
int, + temporal_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + log.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1 + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1 + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1) + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed in the encoder should correspond + # to the layer index, inreverse order, where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. + # For example: + # Input tensor = (1, 3, 17, 32, 32) + # Patch size = 4 for 3D wavelet transform + # Compression rate = (8x16x16) + # + # We expect successive downsampling in the encoder and upsampling in the decoder to be mirrored. + # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)` + # DECODER: `(...,3,2,2) -> (...,3,4,4) -> (...,5,8,8)` + # + # if legacy_mode is True, the temporal upsampling is not perfectly mirrored. + # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)` + # DECODER: `(...,3,2,2) -> (...,5,4,4) -> (...,5,8,8)` + # + # Most of the CV and DV tokenizers were trained before 09/01/2024 with upsampling that's not mirrored. + # Going forward, new CV/DV tokenizers will adopt `legacy_mode=False`, i.e. use mirrored upsampling. 
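# Example of the resulting schedule (illustrative numbers only: num_resolutions=3 with
# num_spatial_ups=2 and num_temporal_ups=1, i.e. the 8x16x16, patch_size=4 case above):
#   i_level=2 (first upsample applied): i_level_reverse=0 -> legacy temporal_up=True,  mirrored temporal_up=False
#   i_level=1 (second upsample):        i_level_reverse=1 -> legacy temporal_up=False, mirrored temporal_up=True
# i.e. legacy_mode doubles time at the deepest level, the mirrored mode one level later.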
+ i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h diff --git a/cosmos_predict1/autoregressive/tokenizer/networks.py b/cosmos_predict1/autoregressive/tokenizer/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..56b0c5fb7a1dec7e6282c66f7a34253925c11ffe --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/networks.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
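Tying the decoder comment above back to concrete sizes: with the causal 8x16x16 compression used by the discrete video tokenizer earlier in this diff, the first frame is kept and the remaining frames are divided by the temporal factor, which is why the defaults pair `pixel_chunk_duration=25` with `latent_chunk_duration=4`. A short check using the same `(1, 3, 17, 32, 32)` example as that comment:

```python
# Back-of-envelope (T, H, W) latent grid for the causal 8x16x16 tokenizer.
frames, height, width = 17, 32, 32
t_comp, s_comp = 8, 16

latent_t = 1 + (frames - 1) // t_comp                     # first frame kept, rest divided by 8
latent_h, latent_w = height // s_comp, width // s_comp
assert (latent_t, latent_h, latent_w) == (3, 2, 2)        # matches the ENCODER/DECODER shapes above

# And the tokenizer defaults: 25 pixel frames -> 1 + 24 // 8 = 4 latent frames.
assert 1 + (25 - 1) // t_comp == 4
```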
+ +from collections import namedtuple + +import torch +from torch import nn + +from cosmos_predict1.autoregressive.tokenizer.modules import CausalConv3d, DecoderFactorized, EncoderFactorized +from cosmos_predict1.autoregressive.tokenizer.quantizers import FSQuantizer +from cosmos_predict1.utils import log + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class CausalDiscreteVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") + self.embedding_dim = embedding_dim + self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs) + self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) + self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) + + self.quantizer = FSQuantizer(**kwargs) + + num_parameters = sum(param.numel() for param in self.parameters()) + log.debug(f"model={self.name}, num_parameters={num_parameters:,}") + log.debug(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info) + return NetworkEval(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info) diff --git a/cosmos_predict1/autoregressive/tokenizer/patching.py b/cosmos_predict1/autoregressive/tokenizer/patching.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5b621f9d526cff7966c77225656e9327adde30 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/patching.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The patcher and unpatcher implementation for 2D and 3D data.""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = False + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. 
+ + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=self.patch_size, p2=self.patch_size).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", patch_size * torch.ones([1], dtype=torch.int32), persistent=_PERSISTENT + ) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. 
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. 
+ yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, rescale=True) + return x + + def _iarrange(self, x): + x = rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=self.patch_size, p2=self.patch_size) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, rescale=False): + dtype = x.dtype + h = self.wavelets + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. + xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] 
+ return x diff --git a/cosmos_predict1/autoregressive/tokenizer/quantizers.py b/cosmos_predict1/autoregressive/tokenizer/quantizers.py new file mode 100644 index 0000000000000000000000000000000000000000..5618d1fbffa4fc68133332de23e3d8fb5dd03dc5 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/quantizers.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantizers for discrete image and video tokenization.""" + +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange + +from cosmos_predict1.autoregressive.tokenizer.utils import default, pack_one, round_ste, unpack_one + + +class FSQuantizer(nn.Module): + """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 + + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/finite_scalar_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + levels: list[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + **ignore_kwargs, + ): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.float32) + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) + + def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / 
half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat).float() + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: + """Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes.to(self.dtype) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + is_img_or_video = z.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if is_img_or_video: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + indices = unpack_one(indices, ps, "b * c") + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True)) + else: + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 1 -> ...") + + return (indices, out.to(self.dtype), dummy_loss) diff --git a/cosmos_predict1/autoregressive/tokenizer/text_tokenizer.py b/cosmos_predict1/autoregressive/tokenizer/text_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8cb96aefc49922d24173cdb6af1c58fec231bc --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/text_tokenizer.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import AutoTokenizer + +from cosmos_predict1.utils import log + + +def get_tokenizer_path(model_family: str, is_instruct_model: bool = False): + """ + Get the tokenizer path from the model family and instruct model flag. + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + Returns: + str: The tokenizer path. + """ + model_family = model_family.lower() + if model_family == "mistral": + return "mistralai/Mistral-Nemo-Instruct-2407" + else: + assert model_family in ["llama3", "llama3.1"] + if model_family == "llama3": + model_path = "meta-llama/Meta-Llama-3-8B" + elif model_family == "llama3.1": + model_path = "meta-llama/Llama-3.1-8B" + else: + raise ValueError(f"Unsupported model family: {model_family}") + suffix = "-Instruct" if is_instruct_model else "" + model_path = f"{model_path}{suffix}" + return model_path + + +class TextTokenizer: + """ + Text tokenizer class built on HuggingFace's Fast Tokenizer (Rust based). + """ + + def __init__( + self, + model_family: str, + is_instruct_model: bool, + local_path: Optional[str] = None, + ): + """ + Initialize the TextTokenizer. + Args: + model_family (str): The model family. + is_instruct_model (bool): Whether the model is an instruct model. + local_path (Optional[str]): The local path to the tokenizer. If not provided, the tokenizer will be downloaded from the remote path. + """ + if local_path is None: + tokenizer_path = get_tokenizer_path(model_family, is_instruct_model) + else: + tokenizer_path = local_path + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) + self.stop_tokens = { + self.tokenizer.eos_token_id, + } + self.model_family = model_family + self.is_instruct_model = is_instruct_model + self.eos_id = self.tokenizer.eos_token_id + if self.tokenizer.pad_token is None: + if model_family.startswith("llama"): + self.pad_id = 128004 # "<|finetune_right_pad_id|>" + elif model_family == "mistral": + self.pad_id = 10 # "" + elif model_family == "pixtral": + self.pad_id = 11 # "" + else: + raise ValueError(f"pad_id not defined for model_family {model_family}") + else: + self.pad_id = self.tokenizer.pad_token_id + + def tokenize(self, text: str, *, add_special_tokens: bool = False, **kwargs) -> List[str]: + """ + Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`. + + Args: + text (`str`): + The sequence to be encoded. + add_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add the special tokens associated with the corresponding model. + Returns: + `List[str]`: The list of tokens. 
+ """ + return self.tokenizer.tokenize(text, add_special_tokens=add_special_tokens, **kwargs) + + def encode( + self, + text: Union[str, List[str], List[int]], + *, # Enforce keyword-only arguments + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = None, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[str] = None, + **kwargs, + ) -> List[int]: + """ + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. + + Args: + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to add special tokens when encoding the sequences. This will use the underlying + `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are + automatically added to the input ids. This is usefull if you want to add `bos` or `eos` tokens + automatically. + padding (`bool`, `str`, *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str`, *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. 
+ stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + """ + return self.tokenizer.encode( + text, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + ) + + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"], + *, # Enforce keyword-only arguments + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str`: The decoded sentence. + """ + return self.tokenizer.decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + *, + add_generation_prompt: bool = False, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[str] = None, + return_dict: bool = False, + return_assistant_tokens_mask: bool = False, + generation_prefix: str = "", + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token + ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting. 
+ + More details can be found at https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template + + Args: + conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts + with "role" and "content" keys, representing the chat history so far. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. + Note that this argument will be passed to the chat template, and so it must be supported in the + template for this argument to have any effect. + continue_final_message (bool, *optional*): + If this is set, the chat will be formatted so that the final + message in the chat is open-ended, without any EOS tokens. The model will continue this message + rather than starting a new one. This allows you to "prefill" part of + the model's response for it. Cannot be used at the same time as `add_generation_prompt`. + tokenize (`bool`, defaults to `True`): + Whether to tokenize the output. If `False`, the output will be a string. + padding (`bool`, defaults to `False`): + Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`. + truncation (`bool`, defaults to `False`): + Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. + max_length (`int`, *optional*): + Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If + not specified, the tokenizer's `max_length` attribute will be used as a default. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable + values are: + - `'tf'`: Return TensorFlow `tf.Tensor` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. + generation_prefix (str): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "". + tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. + return_assistant_tokens_mask (`bool`, defaults to `False`): + Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, + the mask will contain 1. For user and system tokens, the mask will contain 0. + This functionality is only available for chat templates that support it via the `{% generation %}` keyword. + **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + + Returns: + `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This + output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is + set, will return a dict of tokenizer outputs instead. + """ + if not self.is_instruct_model: + raise ValueError( + "apply_chat_template is only supported for instruct models. You should pass argument is_instruct_model=True to the TextTokenizer constructor." 
+ ) + # Since generation_prefix is added to the text in the end, ensure that the setting is correct + if generation_prefix: + assert not tokenize, "tokenize must be False when generation_prefix is provided." + assert add_generation_prompt, "add_generation_prompt must be set when generation_prefix is provided." + formatted_text: Union[str, List[int]] = self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_dict=return_dict, + return_assistant_tokens_mask=return_assistant_tokens_mask, + tokenizer_kwargs=tokenizer_kwargs, + **kwargs, + ) + if generation_prefix: + formatted_text: str = formatted_text + generation_prefix + log.debug( + f"Adding generation prefix: {generation_prefix} to the formatted text\n" + f"Formatted text: {formatted_text}" + ) + return formatted_text diff --git a/cosmos_predict1/autoregressive/tokenizer/tokenizer.py b/cosmos_predict1/autoregressive/tokenizer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbda5f7c4eccc36ba85eca89fe5dd841fcc0786 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/tokenizer.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
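The TextTokenizer wrapper above is driven through encode/decode for plain text and apply_chat_template for instruct models. As an illustrative usage sketch (not part of the patch; the text_tokenizer module path, model choice, and prompt strings are assumptions):

from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer  # assumed module path

# Instantiate for an instruct model so that apply_chat_template is permitted.
tokenizer = TextTokenizer(model_family="llama3.1", is_instruct_model=True)

# Plain encoding: truncate to a fixed length and return a PyTorch tensor.
ids = tokenizer.encode(
    "A robot arm stacks two red cubes.",
    truncation=True,
    max_length=128,
    return_tensors="pt",
)

# Chat-style prompting: generation_prefix requires tokenize=False and
# add_generation_prompt=True, so the prefix is appended to the rendered string.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Describe the scene in one sentence."}],
    add_generation_prompt=True,
    tokenize=False,
    generation_prefix=" The scene shows",
)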
+ +from collections import defaultdict +from typing import Optional + +import torch +from einops import rearrange + +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + + +def update_vocab_size( + existing_vocab_size, + to_be_added_vocab_size, + training_type, + add_special_tokens, + video_special_tokens={}, +): + # New vocab size + if add_special_tokens: + existing_vocab_size += to_be_added_vocab_size + len(video_special_tokens) + # For text_to_video, we add one special token at the beginning of the video + elif training_type == "text_to_video": + existing_vocab_size += to_be_added_vocab_size + 1 + else: + existing_vocab_size += to_be_added_vocab_size + return existing_vocab_size + + +class DiscreteMultimodalTokenizer: + def __init__(self, tokenizer_config: TokenizerConfig): + self.tokenizer_config = tokenizer_config + self.vocab_size = 0 + self.total_seq_len = tokenizer_config.seq_len + self.pad_to_multiple_of = tokenizer_config.pad_to_multiple_of + self.training_type = tokenizer_config.training_type + assert self.training_type in [ + "text_only", + "text_to_video", + "video_to_video", + "image_text_interleaved", + ], f"{self.training_type} not supported" + + self._build_text_tokenizer() + self._build_video_tokenizer() + + def _build_text_tokenizer(self): + r"""Function to initialize the text tokenizer model.""" + if self.tokenizer_config.text_tokenizer is not None: + self.text_tokenizer = lazy_instantiate(self.tokenizer_config.text_tokenizer.config) + self.vocab_size += self.tokenizer_config.text_tokenizer.vocab_size + else: + self.text_tokenizer = None + + def _build_video_tokenizer(self): + r"""Function to initialize the video tokenizer model.""" + if self.tokenizer_config.video_tokenizer is not None: + self.video_tokenizer = lazy_instantiate(self.tokenizer_config.video_tokenizer.config) + self.video_tokenizer = self.video_tokenizer.to("cuda") + self.video_vocab_size = self.tokenizer_config.video_tokenizer.vocab_size + special_token_offset = ( + self.tokenizer_config.video_tokenizer.tokenizer_offset + + self.tokenizer_config.video_tokenizer.vocab_size + ) + self.video_special_tokens = { + "<|begin_of_video|>": special_token_offset, + "<|end_of_video|>": special_token_offset + 1, + "<|pad_token_video|>": special_token_offset + 2, + } + + self.vocab_size = update_vocab_size( + existing_vocab_size=self.vocab_size, + to_be_added_vocab_size=self.tokenizer_config.video_tokenizer.vocab_size, + training_type=self.training_type, + add_special_tokens=self.tokenizer_config.add_special_tokens, + video_special_tokens=self.video_special_tokens, + ) + else: + self.video_tokenizer = None + + @property + def pad_id(self): + r"""Returns the pad_id.""" + + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + pad_id = self.text_tokenizer.pad_id + elif self.training_type in ["text_to_video", "video_to_video"]: + pad_id = self.video_special_tokens["<|pad_token_video|>"] + else: + raise ValueError(f"training_type {self.training_type} not defined") + return pad_id + + @property + def ignore_index(self): + r"""Returns which token should be ignored during loss computation.""" + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + if self.text_tokenizer.pad_id == self.text_tokenizer.eos_id: + # If the PAD token is the same as the EOS token, we do not ignore it during loss + # computation, since we want the model to be able to 
predict EOS tokens in inference. + # The PyTorch default ignore_index for the cross-entropy loss is -100. + ignore_index = -100 + else: + ignore_index = self.text_tokenizer.pad_id + elif self.training_type in ["text_to_video", "video_to_video"]: + ignore_index = self.pad_id + else: + raise ValueError(f"training_type {self.training_type} not defined") + return ignore_index + + @property + def stop_tokens(self): + r"""Returns the stop tokens.""" + if self.training_type == "text_only" or self.training_type == "image_text_interleaved": + stop_tokens = self.text_tokenizer.stop_tokens + elif self.training_type in ["text_to_video", "video_to_video"]: + stop_tokens = set([self.video_special_tokens["<|end_of_video|>"]]) + else: + raise ValueError(f"training_type {self.training_type} not defined") + return stop_tokens + + def _tokenize_text(self, raw_text: list[str], max_text_seq_len: int = -1): + r"""Function to tokenize text. + Args: + raw_text (list[str]): List of input strings + max_text_seq_len (int): Maximum sequence length returned by text tokenizer + Returns: + text_tokens (list[list[int]]): List of text tokens + """ + + batch_size = len(raw_text) + text_tokens = [self.text_tokenizer.encode(raw_text[i], bos=True, eos=True) for i in range(batch_size)] + + # Clipping the text tokens so that the sequence length does not exceed max_text_seq_len + if max_text_seq_len > -1: + for i in range(len(text_tokens)): + if len(text_tokens[i]) > max_text_seq_len: + # Simply clip and add end of seq token + text_tokens[i] = text_tokens[i][0 : max_text_seq_len - 1] + [self.text_tokenizer.eos_id] + return text_tokens + + def _tokenize_class(self, cls_labels: list[str]): + r"""Function to tokenize the class label. + Args: + cls_labels (list[str]): List of class indices + Returns: + class_tokens (list[list[int]]): List of class tokens + """ + + # tokenizer_offset tells what offset should be added to the tokens. + # This is needed for vocab expansion. + class_tokens = [[int(x) + self.tokenizer_config.class_tokenizer.tokenizer_offset] for x in cls_labels] + + return class_tokens + + def _tokenize_video(self, videos: torch.Tensor, pixel_chunk_duration: Optional[int] = None): + r"""Function to tokenize video. + Args: + videos (torch.Tensor): Input video data tensor + pixel_chunk_duration (Optional[float]): Pixel chunk duration. If provided, we pass it to the video tokenizer. + Returns: + video_tokens (list[list[int]]): List of video tokens + """ + + video_tokens = [] + batch_size = videos.shape[0] + + quantized_out, _ = self.video_tokenizer.encode(videos, pixel_chunk_duration=pixel_chunk_duration) + indices = self.video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1)) + + # Flatten the indices + indices = rearrange(indices, "B T H W -> B (T H W)") + + # tokenizer_offset tells what offset should be added to the tokens. + # This is needed for vocab expansion. 
+ indices += self.tokenizer_config.video_tokenizer.tokenizer_offset + + # Add begin and end of video tokens + bov_token = self.video_special_tokens["<|begin_of_video|>"] + eov_token = self.video_special_tokens["<|end_of_video|>"] + + # Append bov and eov tokens + if self.tokenizer_config.add_special_tokens: + for i in range(batch_size): + video_tokens.append([bov_token] + indices[i].tolist() + [eov_token]) + else: + if self.training_type == "text_to_video": + for i in range(batch_size): + video_tokens.append([bov_token] + indices[i].tolist()) + else: + for i in range(batch_size): + video_tokens.append(indices[i].tolist()) + assert ( + len(video_tokens[-1]) == self.tokenizer_config.video_tokenizer.max_seq_len + ), f"Expected {self.tokenizer_config.video_tokenizer.max_seq_len} tokens, got {len(video_tokens[-1])}; video shape: {videos.shape}" + + return video_tokens + + def tokenize(self, data_batch: dict): + r"""Function to tokenize data_dict. + Args: + data_batch (dict): Input data dict + Returns: + tokens (torch.LongTensor): Token tensor dict + """ + + if ( + self.training_type in ["text_only", "image_text_interleaved"] + and not self.tokenizer_config.text_tokenizer.tokenize_here + ): + # In case of pre-computed tokens, just return the data_batch + return data_batch["tokens"], None + + # Online tokenization + tokens = [] + token_boundaries = defaultdict(list) + + # Obtain maximum sequence length + max_text_seq_len = -1 + max_visual_seq_len = -1 + + if self.training_type in ["text_to_video", "video_to_video"]: + max_visual_seq_len = self.tokenizer_config.video_tokenizer.max_seq_len + + # If max visual sequence length is specified, make sure that text is clipped so that + # the full video/image is always seen. + if max_visual_seq_len > -1: + if self.tokenizer_config.add_special_tokens: + max_visual_seq_len = max_visual_seq_len + 2 # Two special tokens is for [bov, eov] or [boi, eoi] token + elif self.training_type == "text_to_video": + max_visual_seq_len = max_visual_seq_len + 1 + else: + max_visual_seq_len = max_visual_seq_len + assert ( + max_visual_seq_len <= self.total_seq_len + ), f"max_visual_seq_len ({max_visual_seq_len}) is greater that total sequence length ({self.total_seq_len})" + max_text_seq_len = self.total_seq_len - max_visual_seq_len + + # Tokenize the text + if ( + "text" in self.training_type + and self.text_tokenizer is not None + and self.tokenizer_config.text_tokenizer.tokenize_here + ): + key = self.tokenizer_config.text_tokenizer.data_key + batch_size = len(data_batch[key]) + assert key in data_batch, f"Key {key} should be present in data for text tokenizer" + tokens = self._tokenize_text(data_batch["caption"], max_text_seq_len) + + for i in range(batch_size): + token_boundaries["text"].append((0, len(tokens[i]))) + else: + tokens = [] + batch_size = None + + # Tokenize the class label + if "class" in self.training_type and self.tokenizer_config.class_tokenizer is not None: + key = self.tokenizer_config.class_tokenizer.data_key + assert key in data_batch, f"Key {key} should be present in data for class tokenizer" + batch_size = len(data_batch[key]) if batch_size is None else batch_size + tokens_class = self._tokenize_class(data_batch[key]) + if len(tokens) == 0: + tokens = tokens_class + for i in range(batch_size): + token_boundaries["class"].append((0, len(tokens[i]))) + else: + for i in range(batch_size): + token_boundaries["class"].append((len(tokens[i]), len(tokens[i]) + len(tokens_class[i]))) + tokens[i] = tokens[i] + tokens_class[i] + + # Tokenize the video + if 
self.video_tokenizer is not None and self.tokenizer_config.video_tokenizer.tokenize_here: + key = self.tokenizer_config.video_tokenizer.data_key + assert key in data_batch, f"Key {key} should be present in data for video tokenizer" + batch_size = len(data_batch[key]) if batch_size is None else batch_size + + pixel_chunk_duration = ( + None # If not specified, we assume it's a video dataset and use the default chunk duration + ) + dataset_name = data_batch.get("dataset_name", None) + if dataset_name is not None and dataset_name.startswith("image"): + # If it's an image dataset, we use a pixel chunk duration of 1 + pixel_chunk_duration = 1 + tokens_video = self._tokenize_video(data_batch[key], pixel_chunk_duration=pixel_chunk_duration) + if len(tokens) == 0: + tokens = tokens_video + for i in range(batch_size): + token_boundaries["video"].append((0, len(tokens[i]))) + # [B,] each entry is ((0, len(tokens[i]))) + else: + for i in range(batch_size): + token_boundaries["video"].append((len(tokens[i]), len(tokens[i]) + len(tokens_video[i]))) + tokens[i] = tokens[i] + tokens_video[i] + + # Combine the tokens and do padding + max_seq_len_in_batch = max([len(token) for token in tokens]) + if self.pad_to_multiple_of is not None: + # Pad the sequence length to the nearest multiple of pad_to_multiple_of + max_seq_len_in_batch = ((max_seq_len_in_batch - 1) // self.pad_to_multiple_of + 1) * self.pad_to_multiple_of + pad_to_len = min(max_seq_len_in_batch, self.total_seq_len) + for i in range(len(tokens)): + if len(tokens[i]) < pad_to_len: + tokens[i] = tokens[i] + [self.pad_id] * (pad_to_len - len(tokens[i])) + else: + tokens[i] = tokens[i][0:pad_to_len] + + # Convert it to long tensor + tokens = torch.LongTensor(tokens) + return tokens, token_boundaries diff --git a/cosmos_predict1/autoregressive/tokenizer/utils.py b/cosmos_predict1/autoregressive/tokenizer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dd58c7830e60e5a09a38b991ccb5fef3b13293 --- /dev/null +++ b/cosmos_predict1/autoregressive/tokenizer/utils.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
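The padding rule at the end of DiscreteMultimodalTokenizer.tokenize above rounds the longest sequence in the batch up to the next multiple of pad_to_multiple_of and clamps the result at total_seq_len. A standalone sketch of that arithmetic (the helper name and the example numbers are illustrative, not from the patch):

def pad_target_len(max_len_in_batch: int, pad_to_multiple_of: int, total_seq_len: int) -> int:
    # ceil(max_len_in_batch / pad_to_multiple_of) * pad_to_multiple_of, clamped at total_seq_len.
    rounded = ((max_len_in_batch - 1) // pad_to_multiple_of + 1) * pad_to_multiple_of
    return min(rounded, total_seq_len)

assert pad_target_len(301, pad_to_multiple_of=64, total_seq_len=4096) == 320    # rounded up to a multiple of 64
assert pad_target_len(4090, pad_to_multiple_of=64, total_seq_len=4096) == 4096  # clamped at the sequence cap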
+ +from typing import Any + +import torch +from einops import pack, rearrange, unpack + + +def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size = x.shape[0] + return rearrange(x, "b c t h w -> (b t) c h w"), batch_size + + +def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: + return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + + +def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size, height = x.shape[0], x.shape[-2] + return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height + + +def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: + return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) + + +def cast_tuple(t: Any, length: int = 1) -> Any: + return t if isinstance(t, tuple) else ((t,) * length) + + +def replication_pad(x): + return torch.cat([x[:, :, :1, ...], x], dim=2) + + +def divisible_by(num: int, den: int) -> bool: + return (num % den) == 0 + + +def is_odd(n: int) -> bool: + return not divisible_by(n, 2) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +class CausalNormalize(torch.nn.Module): + def __init__(self, in_channels, num_groups=1): + super().__init__() + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.num_groups = num_groups + + def forward(self, x): + # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. + # All new models should use num_groups=1, otherwise causality is not guaranteed. + if self.num_groups == 1: + x, batch_size = time2batch(x) + return batch2time(self.norm(x), batch_size) + return self.norm(x) + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def round_ste(z: torch.Tensor) -> torch.Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() diff --git a/cosmos_predict1/autoregressive/train.py b/cosmos_predict1/autoregressive/train.py new file mode 100644 index 0000000000000000000000000000000000000000..aa95d4e4738e8e8c6147cdab9970f92b5e192597 --- /dev/null +++ b/cosmos_predict1/autoregressive/train.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
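Two helpers in the utils module above are easy to misread, so here is a small illustrative check (not part of the patch): round_ste rounds in the forward pass while passing gradients straight through, and time2batch/batch2time fold the time axis into the batch axis so that per-frame 2D operations (as in CausalNormalize) can be applied.

import torch

from cosmos_predict1.autoregressive.tokenizer.utils import batch2time, round_ste, time2batch

# Straight-through rounding: forward values are rounded, the gradient is the identity.
z = torch.tensor([0.2, 1.7, -0.4], requires_grad=True)
round_ste(z).sum().backward()
assert torch.equal(z.grad, torch.ones_like(z))

# Fold time into batch and back: (B, C, T, H, W) <-> (B*T, C, H, W).
x = torch.randn(2, 8, 5, 16, 16)
y, batch_size = time2batch(x)
assert y.shape == (10, 8, 16, 16)
assert torch.equal(batch2time(y, batch_size), x)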
+ +import argparse +import importlib +import os + +from loguru import logger as logging +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from omegaconf import OmegaConf + +from cosmos_predict1.utils import misc +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyConfig + + +@misc.timer("instantiate LLM") +def instantiate_model(config, trainer) -> None: + model_parallel_cuda_manual_seed(config.trainer.seed) + model = instantiate(config.model) + if not config.model["model_config"].set_parallel_mode: + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + + return model + + +@logging.catch(reraise=True) +def launch(config, args: argparse.Namespace) -> None: + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + trainer = config.trainer.type(config) + # Create the model + model = instantiate_model(config, trainer) + + model.on_model_init_end() + dataloader_train = instantiate(config.dataloader_train) + dataloader_val = instantiate(config.dataloader_val) + trainer.train( + model, + dataloader_train, + dataloader_val, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Training") + parser.add_argument( + "--config", default="projects.cosmos.ar.v1.configs.train_openhermes", help="Path to the config file" + ) + parser.add_argument("--cluster", default=None, help="Cluster name") + parser.add_argument( + "opts", + help="""Modify config options at the end of the command. For Yacs configs, use + space-separated "PATH.KEY VALUE" pairs. + For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dry run without training. Useful for debugging the config.", + ) + args = parser.parse_args() + config = importlib.import_module(get_config_module(args.config)).make_config() + config = override(config, args.opts) + if args.dryrun: + os.makedirs(config.job.path_local, exist_ok=True) + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + print(OmegaConf.to_yaml(OmegaConf.load(f"{config.job.path_local}/config.yaml"))) + else: + # Launch the training job. + launch(config, args) diff --git a/cosmos_predict1/autoregressive/trainer.py b/cosmos_predict1/autoregressive/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7684384cab86e58071d9eb4da8fc732504f4811b --- /dev/null +++ b/cosmos_predict1/autoregressive/trainer.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
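For orientation (not part of the patch), the train.py entry point above resolves --config to a Python config module exposing make_config() and then applies the trailing "path.key=value" overrides via override(). A rough programmatic equivalent, where the experiment config module and the override keys are made-up placeholders:

import importlib

from cosmos_predict1.utils.config_helper import get_config_module, override

# Same resolution steps as the __main__ block above, with a hypothetical experiment config.
config_module = get_config_module("cosmos_predict1.autoregressive.configs.my_experiment")
config = importlib.import_module(config_module).make_config()

# Dotted overrides, exactly as they would appear on the command line after the flags.
config = override(config, ["trainer.max_iter=10000", "checkpoint.save_iter=1000"])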
+ +import signal + +import torch +import torch.distributed as dist +import torch.utils.data +from megatron.core import parallel_state + +from cosmos_predict1.checkpointer.tp import Checkpointer as TensorParallelCheckpointer +from cosmos_predict1.utils import distributed, ema, log, misc +from cosmos_predict1.utils.checkpointer import Checkpointer +from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class Trainer(Trainer): + def __init__(self, config): + super(Trainer, self).__init__(config) + if config.trainer.distributed_parallelism == "ddp": + if parallel_state.get_tensor_model_parallel_world_size() > 1: + self.checkpointer = TensorParallelCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + log.critical("Using Tensor Parallelism Checkpointer") + else: + self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks) + + elif config.trainer.distributed_parallelism == "fsdp": + self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + else: + raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}") + + """ + Modify the original trainer to log average loss (averaging across all devices and gradient accumulation) + """ + + def train( + self, + model: Model, + dataloader_train: torch.utils.data.DataLoader, + dataloader_val: torch.utils.data.DataLoader, + ) -> None: + """The training function. + + Args: + model (Model): The PyTorch model. + dataloader_train (torch.utils.data.DataLoader): The training data loader. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + """ + # Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models. + model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore + log.info(f"Model Architecture:\n {model}") + model.on_train_start(self.config.trainer.memory_format) + # Initialize the optimizer and scheduler. + self.callbacks.on_optimizer_init_start() + + optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler) + + grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args) + self.callbacks.on_optimizer_init_end() + # Load the model checkpoint and get the starting iteration number. + iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler) + # Set the scheduler to the current iteration. + scheduler.last_epoch = iteration + scheduler._step_count = iteration + 1 + + grad_accum_iter = 0 + log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + if self.config.trainer.distributed_parallelism == "ddp": + # Create a DDP model wrapper. + model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model) + elif self.config.trainer.distributed_parallelism == "fsdp": + model_ddp = model + else: + raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + log.info("Starting training...") + self.callbacks.on_train_start(model, iteration=iteration) + # Initial validation. 
+ if self.config.trainer.run_validation and iteration == 0: + self.validate(model, dataloader_val, iteration=iteration) + _end_training = False + self.callbacks.on_before_dataloading(iteration) + accumulated_loss = 0.0 + + while True: + dataloader_train_iter = iter(dataloader_train) + while True: + self.callbacks.on_before_dataloading(iteration) + try: + data_batch = next(dataloader_train_iter) + except StopIteration: + break + self.callbacks.on_after_dataloading(iteration) + # If max_iter is reached, exit the training loop. + if iteration >= self.config.trainer.max_iter: + _end_training = True + break + # Move all tensors in the data batch to GPU device. + + data_batch = misc.to(data_batch, device="cuda") + # The actual training step. + self.callbacks.on_training_step_start(model, data_batch, iteration=iteration) + model_ddp.train() + output_batch, loss, grad_accum_iter = self.training_step( + model_ddp, + optimizer, + scheduler, + grad_scaler, + data_batch, + iteration=iteration, + grad_accum_iter=grad_accum_iter, + ) + + # Accumulate loss + accumulated_loss += loss.detach() + + # If the gradients are still being accumulated, continue to load the next training batch. + if grad_accum_iter != 0: + if self.enable_one_logger: + # Callback for skipped OneLoggerCallback.on_training_step_end() + self.one_logger.on_train_batch_end(set_barrier=False) + continue + # Do the following when an actual optimizer (update) step has been made. + iteration += 1 + + # Average loss over accumulation steps + grad_accum_avg_loss = accumulated_loss / self.config.trainer.grad_accum_iter + # Average loss across all devices + device_avg_loss = grad_accum_avg_loss.clone() + dist.all_reduce(device_avg_loss, op=dist.ReduceOp.SUM) + device_avg_loss /= dist.get_world_size() + # Reset accumulation variables + accumulated_loss = 0.0 + + self.callbacks.on_training_step_end( + model, data_batch, output_batch, device_avg_loss, iteration=iteration + ) + + # self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration) + + # Validation. + if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0: + self.validate(model, dataloader_val, iteration=iteration) + # Save checkpoint. + if iteration % self.config.checkpoint.save_iter == 0: + self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) + # This iteration is successful; reset the timeout signal. + signal.alarm(self.config.trainer.timeout_period) + if _end_training: + break + log.success("Done with training.") + self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) + self.callbacks.on_train_end(model, iteration=iteration) + self.checkpointer.finalize() + distributed.barrier() + self.callbacks.on_app_end() + + def training_step( + self, + model_ddp: torch.nn.Module | distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + data: dict[str, torch.Tensor], + iteration: int = 0, + grad_accum_iter: int = 0, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]: + """The training step. + + Args: + model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare + module, depending on whether distributed training is enabled or not. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. 
+ grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + grad_accum_iter (int): Number of gradient accumulation iterations. + + Returns: + output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). + loss (torch.Tensor): The total loss of the training data batch. + """ + # Only let DDP sync gradient at the last iteration of the gradient accumulation window + with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1): + with self.training_timer("forward"): + output_batch, loss = model_ddp.training_step(data, iteration) + self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration) + with self.training_timer("backward"): + loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter) + loss_scaled.backward() + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_after_backward() + else: + model_ddp.on_after_backward() + self.callbacks.on_after_backward(model_ddp, iteration=iteration) + grad_accum_iter += 1 + if grad_accum_iter == self.config.trainer.grad_accum_iter: + with self.training_timer("optimizer_step"): + self.callbacks.on_before_optimizer_step( + model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration + ) + grad_scaler.step(optimizer) + grad_scaler.update() + scheduler.step() + self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration) + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + else: + model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + optimizer.zero_grad(set_to_none=True) + grad_accum_iter = 0 + return output_batch, loss, grad_accum_iter + + @torch.no_grad() + def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: + """Validate on the full validation dataset. + + Args: + model (Model): The PyTorch model. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + iteration (int): Current iteration number. + """ + self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) + model.eval() + # Evaluate on the full validation set. + with ema.ema_scope(model, enabled=getattr(model.config.ema, "enabled", False)): + for val_iter, data_batch in enumerate(dataloader_val): + if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: + break + data_batch = misc.to(data_batch, device="cuda") + self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) + output_batch, loss = model.validation_step(data_batch, iteration) + self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) + self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/cosmos_predict1/autoregressive/training/model.py b/cosmos_predict1/autoregressive/training/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3fccb2c37a8c703470005ef46a9a5b6ede146ee5 --- /dev/null +++ b/cosmos_predict1/autoregressive/training/model.py @@ -0,0 +1,1240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import json +import os +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch +import torch.nn.functional as F +from megatron.core import InferenceParams, ModelParallelConfig, parallel_state +from safetensors.torch import load_file +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy, StateDictType +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.autoregressive.configs.base.model import TrainingModelConfig as ModelConfig +from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig +from cosmos_predict1.autoregressive.modules.mm_projector import MultimodalProjector +from cosmos_predict1.autoregressive.networks.vit import VisionTransformer, get_vit_config + +# from cosmos_predict1.autoregressive.training.networks.transformer_medusa import TransformerMedusa +from cosmos_predict1.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer +from cosmos_predict1.autoregressive.training.networks.transformer import ( + Transformer, + TransformerBlock, + TransformerBlockTE, +) +from cosmos_predict1.autoregressive.utils.checkpoint import ( + get_partial_state_dict, + maybe_convert_checkpoint_to_backend, + obtain_tensor_parallel_state_dict, + process_state_dict, + substrings_to_ignore, +) +from cosmos_predict1.autoregressive.utils.misc import random_dropout +from cosmos_predict1.autoregressive.utils.parallel import broadcast_data_batch_in_tp_cp_group, get_batch_on_this_cp_rank +from cosmos_predict1.autoregressive.utils.sampling import ( + decode_n_tokens, + decode_one_token, + prefill, + sample_top_k, + sample_top_p, +) +from cosmos_predict1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.misc import download_from_s3_with_cache, sync_s3_dir_to_local +from cosmos_predict1.utils.model import Model + + +class AutoRegressiveTrainingModel(Model): + """ + A class to build and use a Llama model for text generation. + + Methods: + build: Build a Llama instance by initializing and loading a model checkpoint. + generate: Generate text sequences based on provided prompts using the language generation model. + """ + + def __init__( + self, + model: Transformer, + tokenizer: DiscreteMultimodalTokenizer, + config: ModelConfig, + model_parallel: ModelParallelConfig = None, + vision_encoder: VisionTransformer = None, + mm_projector: MultimodalProjector = None, + ): + """ + Initialize the Llama instance with a model and tokenizer. 
+ + Args: + model (Transformer): The Transformer model for text generation. + tokenizer (Tokenizer): The tokenizer for encoding and decoding text. + config (Config): The configuration for the Llama model. + """ + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.config = config + self.precision = self.model.precision + self.vision_encoder = vision_encoder + self.mm_projector = mm_projector + assert (self.vision_encoder is None) == (self.mm_projector is None), ( + "vision_encoder and mm_projector should be " "both None or not None simultaneously" + ) + self.model_parallel = model_parallel + self.monitor_output_logits = False + self.inference_params = None + # self.insert_medusa_head = self.config.insert_medusa_head + + if self.config.freeze_vision_encoder and vision_encoder is not None: + for param in self.vision_encoder.parameters(): + param.requires_grad = False + log.critical("Vision encoder parameters are frozen.") + + num_params = self.get_num_params() + log.info(f"Number of model parameters: {round(num_params / 1e9, 3)}B") + + def get_num_params( + self, + ) -> int: + """ + Return the number of parameters in the model. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + broadcast_data_batch_in_tp_cp_group(data_batch) + # get the context embedding and mask + context = data_batch.get("context", None) + context_mask = data_batch.get("context_mask", None) + if context is not None: + if self.config.embedding_dropout > 0: + context = random_dropout( + context, + self.config.embedding_dropout, + ) + context = misc.to(context, device="cuda") + if context_mask is not None: + context_mask = misc.to(context_mask, device="cuda") + action = data_batch.get("action", None) + if action is not None: + action = misc.to(action, device="cuda") + # Input tokens + tokens, token_boundaries = self.tokenizer.tokenize(data_batch) + tokens = misc.to(tokens, device="cuda") + # Tokens to predict + labels = data_batch.get("labels", None) + # Token Mask (Note: this is not attention mask) + masks = data_batch.get("token_mask", None) + apply_token_mask = masks is not None + if masks is None: + masks = torch.ones_like(tokens, dtype=torch.bool) + masks = misc.to(masks, device="cuda") + assert ( + data_batch.get("labels", None) is None or apply_token_mask + ), "The code is not tested for the case when both labels and token_mask are provided." 
+ + if self.config.ignore_first_num_tokens > 0: + assert self.config.ignore_first_num_tokens < masks.shape[1] + masks[:, : self.config.ignore_first_num_tokens] = False + seq_len = tokens.shape[1] + + # Boradcast inputs to TP and CP ranks, alternatively we can use the `_broadcast` function from cosmos/diffusion/v1 + # Currently we only handled video tokens (with label and mask) and text tokens (with mask), action and other inputs might also need to be handled + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.model.enable_context_parallel(cp_group) + tokens = get_batch_on_this_cp_rank(tokens) + masks = get_batch_on_this_cp_rank(masks) + if labels is not None: + labels = get_batch_on_this_cp_rank(labels) + if self.vision_encoder is None: + logits = self.model.forward( + tokens=tokens, + input_pos=None, + context=context, + context_mask=context_mask, + action=action, + total_seq_len=seq_len, + ) + else: + assert "images" in data_batch + images = data_batch["images"] + if images.ndim == 5: + # The shape is (batch_size, n_images_per_sample, C, H, W). Flatten the first two dimensions. + images = images.view(-1, *images.shape[2:]) + assert images.ndim == 4, f"Invalid shape: {images.shape}" + token_embeddings = self.embed_vision_language_features(tokens, images) + logits = self.model.forward( + token_embeddings=token_embeddings, + input_pos=None, + context=context, + context_mask=context_mask, + action=action, + total_seq_len=seq_len, + ) + + if labels is None: + # For auto-regressive models, the labels are the same as the + # input tokens shifted by one position + logits = logits[:, :-1] + masks = masks[:, :-1] + labels = tokens[:, 1:].clone() + + batch_size = tokens.shape[0] + # Apply ignore_index + for sample_num in range(batch_size): + if self.tokenizer.training_type == "text_to_video": + # For text-to-video training, we do not compute the loss of text part + # Hence, we set the labels of text tokens to that of ignore_index + if len(token_boundaries["text"]) > 0: + labels[sample_num][0 : token_boundaries["text"][sample_num][1] - 1] = self.tokenizer.ignore_index + elif self.tokenizer.training_type == "class_to_image": + # For class-to-image training, we do not compute the loss of class part + # Hence, we set the labels of class tokens to that of ignore_index + labels[sample_num][0 : token_boundaries["class"][sample_num][1] - 1] = self.tokenizer.ignore_index + + ignore_index = self.tokenizer.ignore_index + if self.config.ignore_first_num_tokens > 0 or apply_token_mask: + labels[~masks] = ignore_index + + output_batch = { + "encode_tokens": tokens, + "logits": logits.detach(), + "labels": labels.detach(), + "ignore_index": ignore_index, + } + + if self.monitor_output_logits: + self.gather_output_logits_stats(logits, labels, output_batch, ignore_index) + + logits = logits.flatten(0, 1) + labels = labels.flatten(0, 1) + + # Main cross entropy loss + ce_loss = F.cross_entropy( + input=logits, + target=labels, + ignore_index=ignore_index, # ignore prompt (turn prompt tokens into pad_id here) + ) + + # Z-loss + log_z = torch.logsumexp(logits, dim=-1) # shape: [B, seq_len] + z_loss = self.config.z_loss_coeff * (log_z**2).mean() + + # Combined loss + total_loss = ce_loss + z_loss + + return output_batch, total_loss # skip returning output logits + + @torch.no_grad() + def validation_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Perform a 
validation step for the model, which is the same as the training step (but without backpropagation). + """ + return self.training_step(data_batch, iteration) + + @torch.no_grad() + def gather_output_logits_stats( + self, logits: torch.Tensor, labels: torch.Tensor, output_batch: Dict, ignore_index: int = None + ): + """ + Gather statistics of the output logits, including mean, norm, and max values. + """ + bs, seq_len, dim = logits.shape + logits = logits.reshape(-1, dim) + if ignore_index is not None: + select_index = labels.view(-1) != ignore_index + acc = labels.view(-1)[select_index] == logits.argmax(dim=1)[select_index] + acc = acc.float().mean().view(-1, 1) + + logits = logits[select_index] + output_batch.update( + { + "logits_mean": logits.mean(dim=1).detach(), + "logits_norm": torch.linalg.vector_norm(logits, dim=1).detach(), + "logits_max": logits.max(dim=1).values.detach(), + "acc": acc.detach() * 100, + } + ) + + @torch.no_grad() + def image_encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the image input state to continuous latent and discrete indices. + """ + latent, indices = self.tokenizer.image_tokenizer.encode(state) + return latent, indices + + @torch.no_grad() + def image_decode(self, indices: torch.Tensor) -> torch.Tensor: + """ + Decode the discrete indices to RGB images. + """ + return self.tokenizer.image_tokenizer.decode(indices) + + @torch.no_grad() + def video_encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the video input state to continuous latent and discrete indices. + """ + latent, indices = self.tokenizer.video_tokenizer.encode(state) + return latent, indices + + @torch.no_grad() + def video_decode(self, indices: torch.Tensor) -> torch.Tensor: + """ + Decode the discrete indices to RGB videos. + """ + if self.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap > 0: + return self.tokenizer.video_tokenizer.decode_with_overlap( + indices, temporal_overlap=self.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap + ) + else: + return self.tokenizer.video_tokenizer.decode(indices) + + @staticmethod + def load_llm_checkpoint( + ckpt_path: str = "", + model: Transformer = None, + **kwargs, + ) -> None: + """ + Load a LLM checkpoint from the specified path. + """ + with misc.timer(f"loading checkpoint from {ckpt_path}"): + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) + + @staticmethod + def build( + seed: int = 1, + train_from_scratch: bool = False, + model_config: ModelConfig = ModelConfig(), + fsdp_checkpointer: Any = None, + tokenizer_config: TokenizerConfig = None, + model_parallel: ModelParallelConfig = None, + shard_checkpoint: bool = True, + download_rank_sync: bool = True, + **kwargs, + ) -> "AutoRegressiveTrainingModel": + """ + Build a Llama instance by initializing and loading a model checkpoint. + + Args: + seed (int, optional): Random seed for reproducibility. Defaults to 1. + train_from_scratch (bool, optional): Flag indicating whether to train the model from scratch. Defaults to False. 
+ model_config (ModelConfig, optional): The model configuration for the Llama instance. Defaults to ModelConfig(). + fsdp_checkpointer (Any, optional): The FSDP checkpointer for the Llama instance. Defaults to None. + tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the Llama instance. Defaults to None. + shard_checkpoint (bool, optional): Whether to split the checkpoint by Tensor Parallelism before loading. Defaults to True. + download_rank_sync (bool, optional): Whether to download the checkpoint in a rank-synchronized manner. Defaults to True. + Returns: + AutoRegressiveTrainingModel: An instance of the AutoRegressiveTrainingModel class with the loaded model and tokenizer. + + Raises: + AssertionError: If there are no checkpoint files in the specified directory. + + Note: + This method sets the device to CUDA and loads the pre-trained model and tokenizer. + """ + tensor_parallel_size = 1 if model_parallel is None else model_parallel.tensor_model_parallel_size + # seed must be the same in all processes + torch.manual_seed(seed) + + # Initialize model configuration parameters + llama_params = {} + + # Load checkpoint and model parameters + if not train_from_scratch: + if model_config.ckpt_path is None: + # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir + ckpt_dir = sync_s3_dir_to_local( + s3_dir=model_config.ckpt_dir, + s3_credential_path=model_config.s3_credential_path, + cache_dir=model_config.cache_dir, + ) + + # We prioritize safetensors version over the pytorch version, since the former is + # much faster for checkpoint loading. + checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors")) + if len(checkpoints) == 0: + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert ( + len(checkpoints) == 1 + ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)" + ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case + + if os.path.exists(Path(ckpt_dir) / "params.json"): + with open(Path(ckpt_dir) / "params.json", "r") as f: + llama_params = json.loads(f.read()) + else: + log.info( + f"No params.json found in the checkpoint directory ({ckpt_dir}). " + f"Using default model config."
+ ) + + else: + # If ckpt_path is provided, we load the model from the specified path, + # and use the default model configuration + ckpt_path = download_from_s3_with_cache( + s3_path=model_config.ckpt_path, + s3_credential_path=model_config.s3_credential_path, + cache_dir=model_config.cache_dir, + rank_sync=download_rank_sync, + ) + + for key, value in llama_params.items(): + # Override the default model configuration with the parameters from the checkpoint + setattr(model_config, key, value) + + with misc.timer(f"loading checkpoint from {ckpt_path}"): + if ckpt_path.endswith("safetensors"): + # Load with safetensors API + checkpoint = load_file(ckpt_path, device="cpu") + else: + # The pytorch version + checkpoint = torch.load( + ckpt_path, + map_location="cpu", + mmap=True, # load the checkpoint in memory-mapped mode + ) + llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint + + # If the checkpoint backend is different from the model backend, convert the checkpoint + # to be compatible with the model backend + # If shard_checkpoint is True, the loaded checkpoint is the whole model checkpoint (will be sharded later) + # instead of a tensor-parallel sharded checkpoint + llm_checkpoint = maybe_convert_checkpoint_to_backend( + llm_checkpoint, + target_backend=model_config.backend, + model_config=model_config, + tensor_parallel_size=tensor_parallel_size if not shard_checkpoint else 1, + is_tensor_parallel_shard=tensor_parallel_size > 1 and not shard_checkpoint, + ) + if model_config.vision_encoder is not None: + # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']` + # and `checkpoint['mm_projector']` are both for those weights + # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights + if "vision_encoder" in checkpoint: + log.info("Using pretrained vision_encoder") + vit_checkpoint = checkpoint["vision_encoder"] + else: + log.info("Using fine-tuned vision_encoder") + vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.") + vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.") + if "mm_projector" in checkpoint: + log.info("Using pretrained mm_projector") + projector_checkpoint = checkpoint["mm_projector"] + else: + log.info("Using fine-tuned mm_projector") + projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.") + projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.") + assert ( + len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0 + ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector." 
+ + tokenizer = DiscreteMultimodalTokenizer(tokenizer_config) + + precision = getattr(torch, model_config.precision) + torch.set_default_dtype(precision) + log.info(f"Setting torch default dtype to {precision}") + + # if model_config.insert_medusa_head: + # model = TransformerMedusa( + # params=model_config, + # model_parallel=model_parallel, + # tokenizer_config=tokenizer_config, + # init_weights=train_from_scratch, + # ) + # else: + model = Transformer( + params=model_config, + model_parallel=model_parallel, + tokenizer_config=tokenizer_config, + init_weights=train_from_scratch, + ) + model_kwargs = {} + # [Optional] Initialize vision encoder and multimodal projector (for vision-language tasks) + if model_config.vision_encoder is not None: + assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided." + vit_config = get_vit_config(model_config.vision_encoder) + vision_encoder = VisionTransformer.build( + vit_config, + hidden_dropout=model_config["hidden_dropout"], + attention_dropout=model_config["attention_dropout"], + set_parallel_mode=model_config["set_parallel_mode"], + model_parallel=model_parallel, + attention_tp=tensor_parallel_size > 1, + ) + + mm_projector = MultimodalProjector( + mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"] + ) + model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector}) + + # Perform vocab expansion + if tokenizer.vocab_size > model.vocab_size: + log.info(f"Expanding vocab size to {tokenizer.vocab_size}") + # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer, + expand_output_layer = not (tokenizer.training_type == "text_to_video") + model.expand_vocab(tokenizer.vocab_size, init_method="gaussian", expand_output_layer=expand_output_layer) + + if not train_from_scratch: + if shard_checkpoint: + # Shard the checkpoint according to tensor parallelism. + with misc.timer("sharding checkpoint according to tensor parallelism"): + if model_parallel is not None: + assert model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"] + llm_checkpoint = obtain_tensor_parallel_state_dict( + llm_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=model_config, + ) + if model_config.vision_encoder is not None: + # Shard vision encoder and multimodal projector weights + vit_checkpoint = obtain_tensor_parallel_state_dict( + vit_checkpoint, + tensor_parallel_size=tensor_parallel_size, + tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(), + model_config=vit_config, + ) + + if model_config.vision_encoder is not None: + # Take the LLM weights (starting with "model.") from the VLM checkpoint + llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.") + # Remove the "model." 
prefix in the state_dict + llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.") + with misc.timer("loading state_dict into model"): + missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True) + # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage) + missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")] + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + + if model_config.vision_encoder is not None: + # Load vision encoder and multimodal projector weights + vision_encoder.load_state_dict(vit_checkpoint) + mm_projector.load_state_dict(projector_checkpoint) + if model_config.vision_encoder_in_channels != 3: + vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels) + + model = model.to(precision) # ensure model parameters are in the correct precision + log.info(f"Model config: {model_config}") + + # if model_config.insert_medusa_head: + # from projects.cosmos.ar.v1.model_medusa import LlamaMedusa + + # model_class = LlamaMedusa + # else: + model_class = AutoRegressiveTrainingModel + if model_config.fsdp_enabled: + raise NotImplementedError("FSDP is not implemented for AutoRegressiveTrainingModel") + # model_kwargs["fsdp_checkpointer"] = fsdp_checkpointer + # model_class = FSDPLlama + return model_class(model, tokenizer, model_config, **model_kwargs) + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + top_k: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + logit_clipping_range: list = [], + seed: int = 0, + images: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + top_k (int, optional): Top-k value for top-k sampling. Defaults to None. If not None, top-k sampling will be used instead of top-p sampling. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + assert top_k is None or top_p is None, f"Only one of top_k ({top_k} or top_p ({top_p} should be specified." 
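+ # Rough shape of the loop below: the first forward pass consumes the whole (minimum-length)
+ # prompt, optionally with image embeddings, then each subsequent iteration decodes one token
+ # from the KV cache, sampling with top-p or top-k at the given temperature (argmax when
+ # temperature == 0) until every sequence hits a stop token or max_seq_len is reached.
+ # Hypothetical call, for illustration only (variable name is not from this file):
+ #   out_tokens, _ = trained_model.generate(prompt_tokens=[[1, 5, 42]], max_gen_len=64,
+ #                                          temperature=0.6, top_p=0.9)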
+ if top_p is not None: + log.info(f"Using top-p sampling with p={top_p} and temperature={temperature}") + elif top_k is not None: + log.info(f"Using top-k sampling with k={top_k} and temperature={temperature}") + else: + log.info("Not applying top-k or top-p sampling. Will use top-k sampling with k=None") + + self.model.set_inference_flag(True) + misc.set_random_seed(seed) + # Initialization and Assertions + if isinstance(self.model.params, list): + # During training, model.params is a list + log.info( + f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" + ) + params = self.config + else: + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + if self.config.backend == "transformer_engine": + self.inference_params = InferenceParams( + max_batch_size=params.max_batch_size, max_sequence_length=params.max_seq_len + ) + + # Calculate Prompt Lengths + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = params.max_seq_len + assert ( + max_gen_len + max_prompt_len <= total_len + ), f"max_gen_len + max_prompt_len={max_gen_len + max_prompt_len} exceeds max_seq_len={total_len}" + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + + # Fill tokens tensor with prompt tokens + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + + # Flag to check if image embeddings have been passed to the model - we only need to pass them once + # since we have KV cache. + passed_image_embeddings = False + + # If all prompts are of max length, compute initial logits and logprobs + if min_prompt_len == total_len: + input_pos = torch.arange(tokens.shape[1], dtype=torch.long, device="cuda") + if images is None: + logits = self.model.forward( + tokens=tokens, + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + else: + token_embeddings = self.embed_vision_language_features(tokens, images) + logits = self.model.forward( + token_embeddings=token_embeddings, + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + passed_image_embeddings = True + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), dtype=torch.long, device="cuda") + + # Main generation loop + log.info(f"Start generating the next {total_len - min_prompt_len} tokens. 
This will take a while..") + for cur_pos in range(min_prompt_len, total_len): + input_pos = torch.arange(prev_pos, cur_pos, dtype=torch.long, device="cuda") + if images is not None and not passed_image_embeddings: + token_embeddings = self.embed_vision_language_features(tokens[:, prev_pos:cur_pos], images) + logits = self.model.forward( + token_embeddings=token_embeddings, + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + passed_image_embeddings = True + else: + logits = self.model.forward( + tokens=tokens[:, prev_pos:cur_pos], + input_pos=input_pos, + inference_params=self.inference_params, + context=context, + context_mask=context_mask, + action=action, + ) + + if self.config.backend == "transformer_engine": + self.inference_params.sequence_len_offset += logits.shape[1] + + # Apply temperature scaling and nucleus sampling + if len(logit_clipping_range) > 0: + min_clip_index = logit_clipping_range[0] + max_clip_index = logit_clipping_range[1] + logits_clipped = logits[:, :, min_clip_index:max_clip_index] + else: + logits_clipped = logits + min_clip_index = 0 + + if temperature > 0: + if top_p is not None: + next_token = sample_top_p(logits_clipped, temperature=temperature, top_p=top_p)[0] + else: + next_token = sample_top_k(logits_clipped, temperature=temperature, top_k=top_k)[0] + else: + next_token = torch.argmax(logits_clipped[:, -1, :], dim=-1) + + next_token += min_clip_index + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) + tokens[:, cur_pos] = next_token + # Calculate log probabilities if requested + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + # Check if end-of-sequence token is reached + eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) + prev_pos = cur_pos + # Break the loop if all sequences have reached an end-of-sequence token + if all(eos_reached): + log.info(f"Reach end of sequence, current pos: {cur_pos}; maximum pos: {total_len}") + break + # Convert log probabilities to list if required + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + + # Process and collect the output tokens and log probabilities + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to after eos tok if any + for stop_token in self.tokenizer.stop_tokens: + try: + eos_idx = toks.index(stop_token) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + except ValueError: + pass + out_tokens.append(toks) + out_logprobs.append(probs) + self.model.set_inference_flag(False) + return (out_tokens, out_logprobs if logprobs else None) + + @torch.no_grad() + def fast_generate( + self, + prompt_tokens: List[List[int]] | torch.Tensor, + max_gen_len: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + echo: bool = False, + seed: int = 0, + context: Optional[torch.Tensor] = None, + context_mask: 
Optional[torch.Tensor] = None,
+ action: Optional[torch.Tensor] = None,
+ compile_decode: bool = True,
+ compile_prefill: bool = False,
+ verbose: bool = True,
+ stop_tokens: Optional[Set[int]] = None,
+ ):
+ """
+ Fast auto-regressive generation. Currently only supports input batch size = 1.
+ Args:
+ prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len).
+ max_gen_len (int): Maximum length of the generated text sequence.
+ temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0.
+ top_k (int, optional): Top-k value for top-k sampling. Defaults to None.
+ top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None.
+ num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic.
+ logprobs (bool, optional): Not supported by fast_generate yet; must be False. Defaults to False.
+ echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+ seed (int, optional): Random seed for reproducibility. Defaults to 0.
+ context (torch.Tensor, optional): Optional context tensor passed through to the model. Defaults to None.
+ context_mask (torch.Tensor, optional): Optional context mask passed through to the model. Defaults to None.
+ action (torch.Tensor, optional): Optional action conditioning tensor passed through to the model. Defaults to None.
+ compile_decode (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True.
+ compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False.
+ verbose (bool, optional): Flag indicating whether to log prefill/decode timing and throughput. Defaults to True.
+ stop_tokens (Set[int], optional): Tokens that terminate generation. Defaults to None, in which case the tokenizer's stop tokens are used.
+ """
+ assert (
+ top_p is None or top_k is None
+ ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
+ if top_p is not None:
+ log.info(f"Using top-p sampling with p={top_p} and temperature={temperature}")
+ elif top_k is not None:
+ log.info(f"Using top-k sampling with k={top_k} and temperature={temperature}")
+ else:
+ log.info("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")
+
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.triton.unique_kernel_names = True
+ # Experimental features to reduce compilation times, will be on by default in future
+ torch._inductor.config.fx_graph_cache = True
+ # torch._functorch.config.enable_autograd_cache = True
+
+ self.model.set_inference_flag(True)
+ misc.set_random_seed(seed)
+
+ assert not logprobs, "logprobs are not supported for fast_generate yet"
+ # Check whether the prefill and decode_one_token functions have been compiled yet; if not, compile them based on the flags
+ if compile_decode and not getattr(self, "inference_decode_compiled", False):
+ self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
+ self.inference_decode_compiled = True
+ log.critical("Compiled decode_one_token function. Note: the first run will be slower due to compilation")
+ if compile_prefill and not getattr(self, "inference_prefill_compiled", False):
+ self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
+ self.inference_prefill_compiled = True
+ log.critical("Compiled prefill function. Note: the first run will be slower due to compilation")
+
+ if not hasattr(self, "decode_one_token"):
+ self.decode_one_token = decode_one_token
+ if not hasattr(self, "prefill"):
+ self.prefill = prefill
+
+ # Initialization and Assertions
+ if isinstance(self.model.params, list):
+ # During training, model.params is a list
+ log.info(
+ f"Find self.model.params is a list, use self.config instead.
Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}" + ) + params = self.config + else: + params = self.model.params + if isinstance(prompt_tokens, list): + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda") + if prompt_tokens.ndim == 1: + prompt_tokens = prompt_tokens.view(1, -1) + else: + assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}" + batch_size, prompt_len = prompt_tokens.shape + total_len = min(params.max_seq_len, max_gen_len + prompt_len) + if max_gen_len + prompt_len > params.max_seq_len: + log.warning( + f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}" + ) + max_gen_len = params.max_seq_len - prompt_len + + if context_mask is not None: + context_mask = context_mask.to(dtype=torch.bool) + if context_mask.ndim == 2: + assert ( + context_mask.shape[0] == batch_size + ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}" + # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len] + context_mask = context_mask.view(batch_size, 1, 1, -1) + + if num_gen_seq > 1: + assert ( + batch_size == 1 + ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts" + log.critical(f"Generating {num_gen_seq} sequences with the same prompt") + assert ( + num_gen_seq <= params.max_batch_size + ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}" + # repeat the prompt tokens for num_gen_seq times + prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1) + assert prompt_tokens.shape == ( + num_gen_seq, + prompt_len, + ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}" + batch_size = len(prompt_tokens) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device) + empty[:, :prompt_len] = prompt_tokens + seq = empty + input_pos = torch.arange(0, prompt_len, device="cuda") + + if verbose: + prefill_start = time.time() + + # Prefill stage + next_token = self.prefill( + self.model, + prompt_tokens, + input_pos=input_pos, + temperature=temperature, + top_k=top_k, + top_p=top_p, + context=context, + context_mask=context_mask, + action=action, + ) + if verbose: + prefill_time = time.time() - prefill_start + + seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype) + input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda") + stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens + stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda") + + if verbose: + decode_start = time.time() + # Decode stage + generated_tokens = decode_n_tokens( + self.model, + next_token.view(batch_size, -1), + input_pos, + max_gen_len - 1, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=stop_tokens, + decode_one_token_function=self.decode_one_token, + context=context, + context_mask=context_mask, + action=action, + ) + gen_len = len(generated_tokens) + if verbose: + decode_time = time.time() - decode_start + prefill_throughput = prompt_len / prefill_time + decode_throughput = gen_len / decode_time + log.info(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s") + log.info(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s") + + generated_tokens = 
torch.cat(generated_tokens, dim=1) + + log.critical(f"generated_tokens: {generated_tokens.shape}") + seq = seq[:, : prompt_len + 1 + gen_len] + seq[:, prompt_len + 1 :] = generated_tokens + if not echo: + seq = seq[:, prompt_len:] + return seq, None + + def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.tensor) -> torch.Tensor: + """ + Embed vision and language features into a combined representation. + + Args: + input_ids (torch.Tensor): Input token IDs. + images (torch.tensor): Input images. + + Returns: + torch.Tensor: Combined vision-language features. + + Raises: + AssertionError: If vision encoder or mm projector is not initialized, + or if dimensions mismatch. + """ + # Ensure vision encoder and mm projector are initialized + assert self.vision_encoder is not None + assert self.mm_projector is not None + + # Get image token ID and validate it + image_token_id = self.vision_encoder.image_token_id + assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}" + + # Identify text and image locations in the input + text_locations = input_ids != image_token_id + image_locations = input_ids == image_token_id + + # Process text features + text_features = self.model.tok_embeddings(input_ids[text_locations]) + + # Process image features + images = images.to(device=text_features.device, dtype=text_features.dtype) + vit_outputs = self.vision_encoder(images) + image_features = self.mm_projector(vit_outputs) + + # Get dimensions + B, seq_len = input_ids.shape + N_total = B * seq_len + N_txt, D_txt = text_features.shape + N_img, N_patch, D_img = image_features.shape + + # Reshape image features + image_features = image_features.reshape(N_img * N_patch, D_img) + + # Validate dimensions + assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" + assert ( + N_total == N_txt + N_img * N_patch + ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}" + + # Combine text and image features + combined_features = torch.empty( + (B, seq_len, D_txt), + dtype=text_features.dtype, + device=text_features.device, + ) + combined_features[text_locations, :] = text_features + combined_features[image_locations, :] = image_features + + return combined_features + + def on_after_backward(self, iteration: int = 0): + """ + Hook after loss.backward() is called. + + This method is called immediately after the backward pass, allowing for custom operations + or modifications to be performed on the gradients before the optimizer step. + + So far, this method is used to all-reduce layernorm grads for tensor/sequence parallelism. + + Args: + iteration (int): Current iteration number. + """ + for module in self.children(): + if hasattr(module, "on_after_backward"): + module.on_after_backward(iteration) + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """Hook before zero_grad() is called. + + Args: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + iteration (int): Current iteration number. + """ + for module in self.children(): + if hasattr(module, "on_before_zero_grad"): + module.on_before_zero_grad(optimizer, scheduler, iteration) + + @property + def fsdp_wrap_block_cls(self): + """ + Return the transformer block class to wrap with FSDP. 
+ """ + if self.config.backend == "pytorch": + return TransformerBlock + elif self.config.backend == "transformer_engine": + return TransformerBlockTE + else: + raise ValueError(f"Unknown backend: {self.config.backend}") + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). + """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if strict: + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + return _IncompatibleKeys(actual_missing_keys, unexpected_keys) + + +# class FSDPLlama(Llama): +# def __init__( +# self, model: Transformer, tokenizer: DiscreteMultimodalTokenizer, config: ModelConfig, fsdp_checkpointer: Any +# ): +# self.fsdp_checkpointer = fsdp_checkpointer +# super().__init__(model, tokenizer, config) +# self.set_up_fsdp() + +# def set_up_fsdp(self): +# """ +# Set up FSDP for the model. +# """ + +# model = self.model +# # detach the model from the parent class +# self.model = None +# del self.model + +# # build FSDP sharding strategy and device_mesh +# strategy = { +# "full": ShardingStrategy.FULL_SHARD, +# "hybrid": ShardingStrategy.HYBRID_SHARD, +# "none": ShardingStrategy.NO_SHARD, +# }[self.config.fsdp["sharding_strategy"]] +# log.critical(f"Using {strategy} sharding strategy for FSDP") + +# if self.config.fsdp["sharding_strategy"] == "hybrid": +# sharding_group_size = self.config.fsdp["sharding_group_size"] +# device_mesh = hsdp_device_mesh( +# sharding_group_size=sharding_group_size, +# ) +# else: +# device_mesh = hsdp_device_mesh( +# sharding_group_size=distributed.get_world_size(), +# ) +# parallel_state.fsdp_device_mesh = device_mesh + +# if distributed.get_rank() == 0: +# # only load model in rank0 to reduce network traffic and sync later +# self.fsdp_checkpointer.load_model_during_init(model, is_ema=False) + +# if not hasattr(self, "fsdp_wrap_block_cls"): +# raise ValueError("Networks does not have fsdp_wrap_block_cls attribute, please check the net definition") +# fsdp_blocks_cls = self.fsdp_wrap_block_cls +# fsdp_blocks_cls = ( +# list(fsdp_blocks_cls) if isinstance(fsdp_blocks_cls, (list, tuple, set)) else [fsdp_blocks_cls] +# ) +# log.critical(f"Using FSDP blocks {fsdp_blocks_cls}") + +# log.critical(f"Using wrap policy {self.config.fsdp['policy']}") + +# if self.config.fsdp["policy"] == "size": +# # Size based policy won't work for transformers because the tokenizers need to be accessible at multiple +# # layers (input / output). This is handled by this sharding strategy. +# min_num_params = self.config.fsdp["min_num_params"] +# log.critical(f"Using {min_num_params} as the minimum number of parameters for auto-wrap policy") +# log.info("If using a Transformer model. 
Please use the transformer wrap policy.") +# wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) +# else: +# # Use the auto wrap policy for transformers +# wrap_policy = functools.partial( +# transformer_auto_wrap_policy, +# transformer_layer_cls=set(fsdp_blocks_cls), +# ) +# tensor_kwargs = {"device": "cuda", "dtype": model.precision} + +# # Wrap the model with FSDP and attach it back to this class +# self.model = FSDP( +# model.to(**tensor_kwargs), +# sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync +# sharding_strategy=strategy, +# auto_wrap_policy=wrap_policy, +# device_id=torch.cuda.current_device(), +# device_mesh=device_mesh, +# limit_all_gathers=True, +# use_orig_params=True, # Do not flatten the parameter structure. Useful for layer_dependent lrs, etc. +# ) + +# if self.config.act_ckpt_enabled: +# # Apply activation checkpointing +# apply_fsdp_checkpointing(self.model, list_block_cls=fsdp_blocks_cls) + +# # Clean up memory +# torch.cuda.empty_cache() + +# def state_dict(self) -> Dict: +# raise NotImplementedError("FSDPLlama does not support state_dict, use state_dict_model and FSDPCheckpointer") + +# @misc.timer("FSDP state_dict_model") +# def state_dict_model(self) -> Dict: +# """ +# Get the model state_dict for checkpoint saving in the FSDP mode. +# """ +# with FSDP.summon_full_params(self.model): +# pass +# with FSDP.state_dict_type( +# self.model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) +# ): +# model_state = self.model.state_dict() +# # No support for EMA yet. +# ema_model_state = None +# return { +# "model": model_state, +# "ema": ema_model_state, +# } + +# def load_state_dict(self, state_dict: Dict, strict: bool = True, assign: bool = False) -> None: +# raise NotImplementedError("FSDPLlama does not support load_state_dict, using FSDPCheckpointer") + +# def init_optimizer_scheduler( +# self, optimizer_config: LazyDict, scheduler_config: LazyDict +# ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: +# """ +# Initialize the optimizer and scheduler for FSDP model. + +# Args: +# optimizer_config (LazyDict): The optimizer configuration. +# scheduler_config (LazyDict): The scheduler configuration. + +# Returns: +# tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: The optimizer and scheduler. +# """ +# optimizer, scheduler = super().init_optimizer_scheduler(optimizer_config, scheduler_config) +# self.fsdp_checkpointer.load_optim_scheduler_during_init( +# self.model, +# optimizer, +# scheduler, +# ) +# return optimizer, scheduler + +# def get_ckpt_postfix(self) -> Tuple[str, int]: +# """Get the checkpoint file postfix. check FSDPCheckpointer for more details + +# Returns: +# postfix (str): The postfix of the checkpoint file. +# replicate_idx, shard_idx (int), current gpu replicate_idx, shard_idx in FSDP \ +# we will not save each ema model in each GPU, \ +# ema model with same rate will be saved once +# total_ema_num (int) +# """ +# replicate_idx, shard_idx = parallel_state.fsdp_device_mesh.get_coordinate() +# # !!! 
EMA is not supported +# if replicate_idx == 0: +# return "", 0, shard_idx, 0 +# return "", replicate_idx, shard_idx, 0 diff --git a/cosmos_predict1/autoregressive/training/modules/attention.py b/cosmos_predict1/autoregressive/training/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1b6c81cb10a34033da0ccf794991e89d2501df --- /dev/null +++ b/cosmos_predict1/autoregressive/training/modules/attention.py @@ -0,0 +1,734 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Optional, Tuple, Union + +import torch +from megatron.core import ModelParallelConfig, parallel_state +from torch import nn +from torch.distributed import _functional_collectives as funcol +from transformer_engine.pytorch.attention import _SplitAlongDim, apply_rotary_pos_emb, check_set_window_size +from transformer_engine.pytorch.constants import AttnBiasTypes +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.module.linear import Linear as LinearTE +from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE + +from cosmos_predict1.autoregressive.modules.embedding import RotaryPositionEmbedding +from cosmos_predict1.autoregressive.modules.linear import ColumnParallelLinear, RowParallelLinear +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.utils.parallel import AllReduceBWDRMSNormTE + + +class GQA(nn.Module): + """ + Grouped Query Attention (GQA) with KV cache (only supported for inference). + """ + + def __init__( + self, + n_heads: int, + n_kv_heads: Union[int, None], + dim: int, + max_batch_size: int, + max_seq_len: int, + context_dim: Optional[int] = None, + inference: bool = True, + flash_attn: bool = True, + use_qk_normalization: bool = False, + norm_type: str = "rmsnorm", + norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + set_parallel_mode: Optional[bool] = False, + model_parallel: Optional[ModelParallelConfig] = None, + attention_tp: Optional[bool] = False, + causal_mask: Optional[bool] = True, + head_dim: Optional[int] = None, + fuse_qkv: bool = False, + precision: str = "bfloat16", + attention_type: str = "self", + ): + """ + Initializes the GQA module. + + Args: + n_heads (int): The number of attention heads. + n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads. + dim (int): The dimensionality of the input and output. + max_batch_size (int): The maximum batch size. + max_seq_len (int): The maximum sequence length. + context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None. + inference (bool, optional): Whether the model is in inference mode. Defaults to True. + flash_attn (bool, optional): Whether to use Flash attention. Defaults to True. 
+ use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False.
+ norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm".
+ norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5.
+ attention_dropout (float, optional): Dropout rate for attention. Defaults to 0.0.
+ set_parallel_mode (bool, optional): Whether to enable parallel mode, which uses column/row-parallel linear layers. Defaults to False.
+ model_parallel (ModelParallelConfig, optional): The Megatron model parallel configuration.
+ attention_tp (bool, optional): Whether to use tensor parallelism for attention layers. Defaults to False.
+ causal_mask (bool, optional): Whether to use a causal mask. Defaults to True.
+ head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads.
+ fuse_qkv (bool, optional): Whether to fuse QKV projections. Defaults to False.
+ precision (str, optional): The precision of the model. Defaults to "bfloat16".
+ attention_type (str, optional): The type of attention. Defaults to "self".
+ """
+ super().__init__()
+ assert attention_type in ["self", "cross", "full"], f"Invalid attention type: {attention_type}"
+ self.attention_type = attention_type
+ self.model_parallel = model_parallel
+ if self.model_parallel and self.model_parallel.tensor_model_parallel_size > 1 and attention_tp:
+ self.tp_size = self.model_parallel.tensor_model_parallel_size
+ else:
+ self.tp_size = 1
+
+ context_dim = dim if context_dim is None else context_dim
+
+ self.dim = dim
+ self.context_dim = context_dim
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
+ self.n_local_kv_heads = self.n_kv_heads // self.tp_size
+ self.n_local_heads = n_heads // self.tp_size
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
+ self.head_dim = dim // n_heads if head_dim is None else head_dim
+ assert flash_attn, "Flash attention is required."
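+ # Worked example of the head bookkeeping above (hypothetical sizes): with n_heads=32,
+ # n_kv_heads=8 and tp_size=2, each tensor-parallel rank holds n_local_heads=16 query heads
+ # and n_local_kv_heads=4 KV heads, so each KV head serves n_rep=4 query heads
+ # (the KV tensors are repeat_interleave'd by n_rep in forward()).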
+ self.attention_dropout = attention_dropout + self.causal_mask = causal_mask + self.fuse_qkv = fuse_qkv + self.precision = precision + + if fuse_qkv: + assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})" + self.total_head_dim = (n_heads + 2 * self.n_kv_heads) * self.head_dim + self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim + + if set_parallel_mode and attention_tp and not inference: + kwargs = {"bias": False, "init_method": lambda x: x, "config": self.model_parallel} + # Using column and row parallel linear layers + if fuse_qkv: + self.wqkv = ColumnParallelLinear(dim, self.total_head_dim, **kwargs) + else: + self.wq = ColumnParallelLinear(dim, n_heads * self.head_dim, **kwargs) + self.wk = ColumnParallelLinear(context_dim, self.n_kv_heads * self.head_dim, **kwargs) + self.wv = ColumnParallelLinear(context_dim, self.n_kv_heads * self.head_dim, **kwargs) + + # Linear layer for output projection + self.wo = RowParallelLinear( + n_heads * self.head_dim, dim, input_is_parallel=True, skip_bias_add=True, **kwargs + ) + + else: + # Linear layers for query, key, and value projections + if fuse_qkv: + self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False) + else: + self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False) + self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False) + + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + if inference and self.attention_type == "self": + # Cache for key and value tensors + self.init_kv_cache() + + # QK normalization layers + if use_qk_normalization: + assert n_heads % self.tp_size == 0, "n_heads must be divisible by tensor_model_parallel_size" + assert self.n_kv_heads % self.tp_size == 0, "n_kv_heads must be divisible by tensor_model_parallel_size" + self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + self.use_qk_normalization = use_qk_normalization + self.inference = inference + + if fuse_qkv: + # Register hook to load fused QKV weights + self._register_load_state_dict_pre_hook(self.load_hook) + + self.to(dtype=getattr(torch, self.precision)) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def init_kv_cache(self, dtype=None): + cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim) + if dtype is None: + dtype = getattr(torch, self.precision) + if self.attention_type == "self": + self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda() + self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda() + + def set_inference_flag(self, flag): + self.inference = flag + if flag and self.attention_type == "self": + if self.cache_k is None or self.cache_v is None: + self.init_kv_cache() + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbedding, + input_pos: torch.Tensor, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ): + """ + Forward pass of GQA. 
+ + Args: + x: The input tensor of shape (batch_size, seq_len, dim). + rope: The rotary positional embedding module. + input_pos: The starting position of the current sequence. + mask: The attention mask tensor. + context: The context tensor of shape (batch_size, context_len, dim). + + Returns: + The output tensor after applying GQA. + """ + bsz, seqlen, _ = x.shape + + # Use one single module to handle both self-attn and cross-attn + context = x if context is None else context + context_len = seqlen if context is None else context.shape[1] + + if self.fuse_qkv: + q_size = self.n_local_heads * self.head_dim + kv_size = self.n_local_kv_heads * self.head_dim + xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1) + else: + # Compute query, key, and value projections + xq = self.wq(x) + xk, xv = self.wk(context), self.wv(context) + + # Reshape projections + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) + + # QK normalization + if self.use_qk_normalization: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + # Apply rotary positional embeddings to queries and keys + # Only apply RoPE to self-attention! + if self.attention_type in ["self", "full"]: + xq, xk = rope(xq, xk, input_pos, seqlen) + + xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) + # xq: (bs, n_local_heads, seqlen, head_dim) + # xk: (bs, n_kv_heads, cache_len + context_len, head_dim) + # xv: (bs, n_kv_heads, cache_len + context_len, head_dim) + if self.inference and self.attention_type == "self": + # Update cache with current key and value tensors + assert input_pos is not None + self.cache_k[:bsz, :, input_pos] = xk + self.cache_v[:bsz, :, input_pos] = xv + keys, values = ( + self.cache_k[:bsz, :, :], + self.cache_v[:bsz, :, :], + ) + else: + keys, values = xk, xv + + # Repeat keys and values if necessary + keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim) + + if self.attention_type == "self" and self.causal_mask: + # During inference, `is_causal` should be set to False when KV cache is pre-computed and used, + # since the masking is handled outside this attention module. + # During training, `is_causal` should be set to None to use the default behavior of FlashAttention. + is_causal = False if self.inference else None + else: + # This is used for full-attention transformer (e.g., ViT) + # also for the cross-attn, it's always full-attn w/o causal + is_causal = False + output = scaled_dot_product_attention( + xq, + keys, + values, + head_dim=self.head_dim, + mask=mask, + is_causal=is_causal, + dropout_p=self.attention_dropout if self.training else 0.0, + ) + output = output.view(bsz, seqlen, -1) + output = self.wo(output) + + if self.inference and self.tp_size > 1: + output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group()) + return output + + def init_weights(self, init_std: float): + """ + Initializes the weights of all modules. 
+ """ + if self.fuse_qkv: + nn.init.trunc_normal_(self.wqkv.weight, mean=0.0, std=0.02) + else: + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + if self.use_qk_normalization: + torch.nn.init.ones_(self.q_norm.weight) + torch.nn.init.ones_(self.k_norm.weight) + + +def scaled_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + head_dim: int, + mask: Optional[torch.Tensor] = None, + is_causal: Optional[bool] = None, + dropout_p: float = 0.0, +) -> torch.Tensor: + """ + PyTorch's native implementation of Flash Attention 2. + + If `is_causal` is given, then the causal attention mask is applied accordingly: + - If `is_causal` is True, the standard upper-left causal attention masking is applied. + - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is + provided (i.e., `mask is not None`). + + If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied + based on the provided mask tensor: + - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True, + leading to the standard upper-left causal attention masking. + - If an attention mask is given (i.e., `mask is not None`), the provided mask is used, + and `is_causal` is set to False. + + Args: + q (torch.Tensor): Query tensor + k (torch.Tensor): Key tensor + v (torch.Tensor): Value tensor + head_dim (int): Dimension of each attention head + mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None. + is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None. + dropout_p (float, optional): Dropout rate. Defaults to 0.0. + + Returns: + torch.Tensor: Output tensor after applying scaled dot-product attention + """ + scale = 1.0 / math.sqrt(head_dim) + if is_causal is None: + is_causal = mask is None + y = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + scale=scale, + is_causal=is_causal, + ) + return y.transpose(1, 2).contiguous() + + +def enable_different_context_dim_in_te_ca( + te_mha_module, + context_dim, + args, +): + """ + Hijacks the MultiheadAttention (MHA) module from TransformerEngine (TE) to use a different context-dim for KV calculation. + """ + self = te_mha_module + + common_gemm_kwargs = { + "fuse_wgrad_accumulation": args["fuse_wgrad_accumulation"], + "tp_group": self.tp_group, + "tp_size": self.tp_size, + "get_rng_state_tracker": self.get_rng_state_tracker, + "sequence_parallel": self.sequence_parallel, + "params_dtype": self.params_dtype, + } + + self.key_value = LinearTE( + context_dim, + 2 * self.hidden_size_kv, + init_method=None, + bias=args["bias"], + return_bias=False, + parallel_mode="column" if args["set_parallel_mode"] else None, + parameters_split=("key", "value") if not args["fuse_qkv_params"] else None, + **common_gemm_kwargs, + ) + + +def enable_qk_normalization_in_te_mha( + te_mha_module, + norm_eps: float, + is_self_attn: bool = True, +): + """ + Hijacks the MultiheadAttention (MHA) module from TransformerEngine (TE) to use our `te_mha_forward_with_qk_norm`. + The `te_mha_forward_with_qk_norm` function is just a copy of the TE MHA's forward function (source code at + https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) with the addition + of several lines of code for the QK normalization operations. 
+ """ + self = te_mha_module + + # First, we add the QK norm layers (RMSNorm class) to the TE's MHA module in advance for our custom forward function. + if is_self_attn: + common_kwargs = dict( + eps=norm_eps, + device=self.layernorm_qkv.layer_norm_weight.device, + sequence_parallel=self.layernorm_qkv.sequence_parallel, + params_dtype=self.layernorm_qkv.layer_norm_weight.dtype, + zero_centered_gamma=self.layernorm_qkv.zero_centered_gamma, + ) + else: + common_kwargs = dict( + eps=norm_eps, + device=self.layernorm_query.query_weight.device, + sequence_parallel=self.layernorm_query.sequence_parallel, + params_dtype=self.layernorm_query.query_weight.dtype, + zero_centered_gamma=self.layernorm_query.zero_centered_gamma, + ) + if parallel_state.model_parallel_is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + self.q_norm = AllReduceBWDRMSNormTE( + self.hidden_size_per_attention_head, process_group=tp_group, **common_kwargs + ) + self.k_norm = AllReduceBWDRMSNormTE( + self.hidden_size_per_attention_head, process_group=tp_group, **common_kwargs + ) + else: + self.q_norm = RMSNormTE(self.hidden_size_per_attention_head, **common_kwargs) + self.k_norm = RMSNormTE(self.hidden_size_per_attention_head, **common_kwargs) + + # Second, we define the custom forward function for the TE's MHA module, with the QK normalization operations. + def te_mha_forward_with_qk_norm( + hidden_states: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + encoder_output: Optional[torch.Tensor] = None, + attn_mask_type: Optional[str] = None, + window_size: Optional[Tuple[int, int]] = None, + is_first_microbatch: Optional[bool] = None, + checkpoint_core_attention: bool = False, + inference_params: Optional[Any] = None, + rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + core_attention_bias_type: str = "no_bias", + core_attention_bias: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + fast_zero_fill: bool = True, + ) -> Tuple[Union[torch.Tensor, None], ...]: + """ + Forward propagation for MultiheadAttention layer. + + """ + # hidden_states: [sq, b, h] + + if attn_mask_type is None: + attn_mask_type = self.attn_mask_type + if window_size is None: + window_size = self.window_size + window_size = check_set_window_size(attn_mask_type, window_size) + + if "padding" in attn_mask_type and attention_mask is not None: + for mask in attention_mask: + assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" + + assert ( + core_attention_bias_type in AttnBiasTypes + ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" 
+ + # ================================================= + # Pre-allocate memory for key-values for inference + # ================================================= + + if inference_params and self.layer_number is not None: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size, hidden_states.dtype) + inference_value_memory = self._allocate_memory(inf_max_seq_len, inf_max_batch_size, hidden_states.dtype) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, + inference_value_memory, + ) + else: + ( + inference_key_memory, + inference_value_memory, + ) = inference_params.key_value_memory_dict[self.layer_number] + + # ====================== + # Query, Key, and Value + # ====================== + + # fp8_mha = FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.get_fp8_recipe().fp8_mha + # fp8_kwargs = {"fp8_output": fp8_mha and rotary_pos_emb is None} + fp8_kwargs = {} + + layernorm_output = None + if self.attention_type == "self": + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] + layernorm_qkv_outputs = self.layernorm_qkv( + hidden_states, is_first_microbatch=is_first_microbatch, **fp8_kwargs + ) + mixed_x_layer = layernorm_qkv_outputs + + num_queries_per_key_value = self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition + # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, (np/ng + 2), ng, hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + (num_queries_per_key_value + 2), + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + ) + # split along third last dimension + split_dim = -3 + + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, (np/ng + 2), ng, hn] + # --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn] + query_layer, key_layer, value_layer = _SplitAlongDim.apply( + mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) + ) + # query: -> [sq, b, np, hn] + # key, value: -> [sq, b, ng, hn] + query_layer, key_layer, value_layer = ( + x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head) + for x in (query_layer, key_layer, value_layer) + ) + elif self.attention_type == "cross": + # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] + mixed_kv_layer = self.key_value(encoder_output, is_first_microbatch=is_first_microbatch, **fp8_kwargs) + + # [sq, b, (ng * 2 * hn)] --> [sq, b, 2 * ng, hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + 2 * self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + ) + # split along second last dimension + split_dim = -2 + + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # mixed_kv_layer --> 2 [sk, b, ng, hn] + key_layer, value_layer = _SplitAlongDim.apply( + mixed_kv_layer, + split_dim, + mixed_kv_layer.shape[split_dim] // 2, + ) + key_layer, value_layer = ( + x.reshape( + x.size(0), + x.size(1), + -1, + self.hidden_size_per_attention_head, + ) + for x in (key_layer, value_layer) + ) + + # Attention head [sq, b, h] --> [sq, b, hp] + layernorm_query_outputs = self.layernorm_query( + hidden_states, is_first_microbatch=is_first_microbatch, **fp8_kwargs + ) + query_layer = layernorm_query_outputs + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 
self.hidden_size_per_attention_head, + ) + query_layer = query_layer.view(*new_tensor_shape) + + # ====================================================== + # Apply QK normalization (RMSNorm) + # ====================================================== + + # Must use torch.reshape to flatten the tensor, otherwise an error will be triggered in TE's RMSNorm module. + query_layer = self.q_norm(query_layer.reshape(-1, self.hidden_size_per_attention_head)).view(query_layer.shape) + key_layer = self.k_norm(key_layer.reshape(-1, self.hidden_size_per_attention_head)).view(key_layer.shape) + + # ====================================================== + # Apply relative positional encoding (rotary embedding) + # ====================================================== + + if rotary_pos_emb is not None: + assert not isinstance(query_layer, Float8Tensor) and not isinstance( + key_layer, Float8Tensor + ), "RoPE is not supported for Float8Tensors!" + # duplicate the pos_emb for self attention + if not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + q_pos_emb, k_pos_emb = rotary_pos_emb + + # adjust key and value for inference + if inference_params is not None: + if self.qkv_format == "sbhd": + sequence_length = key_layer.size(0) + elif self.qkv_format == "bshd": + sequence_length = key_layer.size(1) + else: + raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.") + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + sequence_length + + q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] + k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] + + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + + # =========================== + # Core attention computation + # =========================== + context_layer = self.core_attention( + query_layer, + key_layer, + value_layer, + qkv_format=self.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attention_mask=attention_mask, + attn_mask_type=attn_mask_type, + window_size=window_size, + checkpoint_core_attention=checkpoint_core_attention, + core_attention_bias_type=core_attention_bias_type, + core_attention_bias=core_attention_bias, + alibi_slopes=alibi_slopes, + fast_zero_fill=fast_zero_fill, + inference_params=inference_params, + ) + + # =================== + # Output. [sq, b, h] + # =================== + + projection_output = self.proj( + context_layer, + is_first_microbatch=is_first_microbatch, + ) + + if self.return_bias: + attention_output, attention_bias = projection_output + else: + attention_output, attention_bias = projection_output, None + + outputs = (attention_output,) + if self.return_bias: + outputs += (attention_bias,) + if self.input_layernorm and self.return_layernorm_output: + outputs += (layernorm_output,) + return outputs if len(outputs) > 1 else outputs[0] + + # Finally, we replace the forward method of given TE's MHA module with our custom forward function. + self.forward = te_mha_forward_with_qk_norm + + +def create_group_causal_attn_mask( + num_temporal_groups: int, num_query_per_group: int, num_key_per_group: int, mode: str = "causal" +) -> torch.Tensor: + """ + Creates a group-based attention mask for scaled dot-product attention with two modes: + 'causal' and 'group_diagonal'. 
+ + Parameters: + - num_temporal_groups (int): The number of temporal groups (e.g., frames in a video sequence). + - num_query_per_group (int): The number of query tokens per temporal group. (e.g., latent tokens in a frame, H x W). + - num_key_per_group (int): The number of key tokens per temporal group. (e.g., action tokens per frame). + - mode (str): The mode of the attention mask. Options are: + - 'causal': Query tokens can attend to key tokens from the same or previous temporal groups. + - 'group_diagonal': Query tokens can attend only to key tokens from the same temporal group. + + Returns: + - attn_mask (torch.Tensor): A boolean tensor of shape (L, S), where: + - L = num_temporal_groups * num_query_per_group (total number of query tokens) + - S = num_temporal_groups * num_key_per_group (total number of key tokens) + The mask indicates where attention is allowed (True) and disallowed (False). + + Example: + Input: + num_temporal_groups = 3 + num_query_per_group = 4 + num_key_per_group = 2 + Output: + Causal Mask Shape: torch.Size([12, 6]) + Group Diagonal Mask Shape: torch.Size([12, 6]) + if mode='causal': + tensor([[ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, False, False], + [ True, True, True, True, True, True], + [ True, True, True, True, True, True], + [ True, True, True, True, True, True], + [ True, True, True, True, True, True]]) + + if mode='group_diagonal': + tensor([[ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [ True, True, False, False, False, False], + [False, False, True, True, False, False], + [False, False, True, True, False, False], + [False, False, True, True, False, False], + [False, False, True, True, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + [False, False, False, False, True, True]]) + + """ + assert mode in ["causal", "group_diagonal"], f"Mode {mode} must be 'causal' or 'group_diagonal'" + + # Total number of query and key tokens + total_num_query_tokens = num_temporal_groups * num_query_per_group # Total number of query tokens (L) + total_num_key_tokens = num_temporal_groups * num_key_per_group # Total number of key tokens (S) + + # Generate time indices for query and key tokens (shape: [L] and [S]) + query_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_query_per_group) # Shape: [L] + key_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_key_per_group) # Shape: [S] + + # Expand dimensions to compute outer comparison + query_time_indices = query_time_indices.unsqueeze(1) # Shape: [L, 1] + key_time_indices = key_time_indices.unsqueeze(0) # Shape: [1, S] + + if mode == "causal": + # Causal Mode: Query can attend to keys where key_time <= query_time + attn_mask = query_time_indices >= key_time_indices # Shape: [L, S] + elif mode == "group_diagonal": + # Group Diagonal Mode: Query can attend only to keys where key_time == query_time + attn_mask = query_time_indices == key_time_indices # Shape: [L, S] + + assert attn_mask.shape == (total_num_query_tokens, total_num_key_tokens), "Attention mask shape mismatch" + return attn_mask diff --git 
a/cosmos_predict1/autoregressive/training/networks/transformer.py b/cosmos_predict1/autoregressive/training/networks/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a69228fe6c46f019d0abf2ecb9371c4bd5f57 --- /dev/null +++ b/cosmos_predict1/autoregressive/training/networks/transformer.py @@ -0,0 +1,1295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +import transformer_engine as te +from megatron.core import InferenceParams, ModelParallelConfig, parallel_state +from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region +from torch.distributed import ProcessGroup +from torch.distributed import _functional_collectives as funcol +from torch.distributed import broadcast, get_process_group_ranks +from torch.nn.modules.module import _IncompatibleKeys +from transformer_engine.pytorch.module.linear import Linear as LinearTE +from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE + +from cosmos_predict1.utils import log + +_ACTION_DIM = 8 +from cosmos_predict1.autoregressive.modules.embedding import ( + RotaryPositionEmbeddingPytorch, + RotaryPositionEmbeddingPytorchV2, + RotaryPositionEmbeddingTE, + SinCosPosEmbAxisTE, + get_pos_emb_on_this_cp_rank, + get_pos_emb_on_this_sptp_rank, +) +from cosmos_predict1.autoregressive.modules.linear import ColumnParallelLinear, TrainingVocabParallelEmbedding +from cosmos_predict1.autoregressive.modules.mlp import TrainingMLP, compute_llama3_ffn_hidden_dim +from cosmos_predict1.autoregressive.modules.normalization import create_norm +from cosmos_predict1.autoregressive.training.modules.attention import ( + GQA, + create_group_causal_attn_mask, + enable_different_context_dim_in_te_ca, + enable_qk_normalization_in_te_mha, +) +from cosmos_predict1.autoregressive.utils.checkpoint import process_state_dict, substrings_to_ignore +from cosmos_predict1.autoregressive.utils.misc import maybe_convert_to_namespace +from cosmos_predict1.autoregressive.utils.parallel import ( + AllReduceBWDRMSNormTE, + allreduce_layernorm_grads, + sync_1d_parameters, +) + +_MLP_HIDDEN_DIM_DIVISOR = ( + 4 # hidden dim of the action embedding layer is action_embedding_dim // _MLP_HIDDEN_DIM_DIVISOR +) + +_T5_NUM_TOKENS = 512 + + +class TransformerBlock(nn.Module): + """ + A single transformer block consisting of an attention layer and a feed-forward layer. + """ + + def __init__(self, layer_id: int, model_parallel: Optional[ModelParallelConfig] = None, args=None): + """ + Initializes the TransformerBlock module. + + Args: + layer_id: The ID of the transformer block. + args: The model arguments containing hyperparameters. 
+ """ + super().__init__() + args = maybe_convert_to_namespace(args) + attention_args = { + "n_heads": args["n_heads"], + "n_kv_heads": args["n_kv_heads"], + "dim": args["dim"], + "context_dim": None, + "max_batch_size": args["max_batch_size"], + "max_seq_len": args["max_seq_len"], + "inference": args["inference"], + "flash_attn": args["flash_attn"], + "use_qk_normalization": args["use_qk_normalization"], + "attention_dropout": getattr(args, "attention_dropout", 0.0), + "set_parallel_mode": args["set_parallel_mode"], + "model_parallel": model_parallel, + "attention_tp": args["attention_tp"], + "causal_mask": args["causal_mask"], + "head_dim": args["head_dim"], + "fuse_qkv": getattr(args, "fuse_qkv", False), + "precision": getattr(args, "precision", "bfloat16"), + "attention_type": getattr(args, "attention_type", "self"), + } + self.attention = GQA(**attention_args) + + self.has_cross_attention = False + self.cross_attention, self.cross_attention_norm = None, None + + if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0: + self.has_cross_attention = True + cross_attention_args = attention_args.copy() + cross_attention_args.update( + {"context_dim": args["context_dim"], "fuse_qkv": False, "attention_type": "cross"} + ) + self.cross_attention = GQA(**cross_attention_args) + self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + self.feed_forward = TrainingMLP( + dim=args["dim"], + hidden_dim=( + compute_llama3_ffn_hidden_dim( + dim=args["dim"], multiple_of=args["multiple_of"], ffn_dim_multiplier=args["ffn_dim_multiplier"] + ) + if args["ffn_hidden_size"] is None + else args["ffn_hidden_size"] + ), + hidden_dropout=getattr(args, "hidden_dropout", 0.0), + set_parallel_mode=args["set_parallel_mode"], + model_parallel=model_parallel, + inference=args["inference"], + ) + self.layer_id = layer_id + self.num_layers = args["n_layers"] + self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"]) + + # If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the + # total number of transformer blocks. Default is `True` (following the TorchTitan implementation of Llama3). + if getattr(args, "depth_init", True): + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + rope: RotaryPositionEmbeddingPytorch, + input_pos: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the TransformerBlock module. + + Args: + x: The input tensor. + input_pos: The position of the current sequence. Used in inference (with KV cache) only. + freqs_cis: The precomputed frequency values for rotary position embeddings. + mask: The attention mask tensor. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + The output tensor after applying the transformer block. 
+ """ + # Apply attention and residual connection + h = x + self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask) + + # If insert cross-attention, apply CA and residual connection + if self.has_cross_attention: + h = h + self.cross_attention( + self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context + ) + + # Apply feed-forward network and residual connection + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + """ + Initializes the weights of the transformer block. + """ + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + if self.has_cross_attention: + self.cross_attention_norm.reset_parameters() + self.cross_attention.init_weights(self.weight_init_std) + # zero-init the final output layer of cross-attention + # nn.init.zeros_(self.cross_attention.wo.weight) + + +class TransformerBlockTE(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. + + Args: + layer_id (int): The ID of the transformer block. + args: The model arguments containing hyperparameters. + """ + + def __init__( + self, + layer_id: int, + args, + tp_group: Optional[ProcessGroup] = None, + set_parallel_mode: bool = False, + attn_input_format: str = "bshd", + ): + attention_args = { + "hidden_size": args["dim"], + "ffn_hidden_size": ( + compute_llama3_ffn_hidden_dim( + dim=args["dim"], multiple_of=args["multiple_of"], ffn_dim_multiplier=args["ffn_dim_multiplier"] + ) + if args["ffn_hidden_size"] is None + else args["ffn_hidden_size"] + ), + "num_attention_heads": args["n_heads"], + "bias": False, + "layernorm_epsilon": args["norm_eps"], + "hidden_dropout": getattr(args, "hidden_dropout", 0.0), + "attention_dropout": getattr(args, "attention_dropout", 0.0), + "normalization": "RMSNorm", + "activation": "swiglu", + "attn_input_format": attn_input_format, + "num_gqa_groups": args["n_kv_heads"], + "fuse_wgrad_accumulation": False, + "fuse_qkv_params": False, + "tp_group": tp_group, + "sequence_parallel": args["sequence_parallel"], + "set_parallel_mode": set_parallel_mode, + "layer_number": layer_id + 1, + "self_attn_mask_type": "causal" if args["causal_mask"] else "no_mask", + "kv_channels": args["head_dim"], # If None, te.pytorch.TransformerLayer defaults it to dim // n_heads + "layer_type": "encoder", + } + self.has_cross_attention = False + if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0: + self.has_cross_attention = True + attention_args["layer_type"] = "decoder" + super().__init__(**attention_args) + if args["use_qk_normalization"]: + # Add QK normalization layers and replace the forward function of original Multi-Head Attention module with + # our custom one to add QK normalization operations. + enable_qk_normalization_in_te_mha(self.self_attention, norm_eps=args["norm_eps"], is_self_attn=True) + + if self.has_cross_attention: + enable_qk_normalization_in_te_mha(self.inter_attention, norm_eps=args["norm_eps"], is_self_attn=False) + + if self.has_cross_attention: + enable_different_context_dim_in_te_ca( + self.inter_attention, context_dim=args["context_dim"], args=attention_args + ) + + self.layer_id = layer_id + self.num_layers = args["n_layers"] + # If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the + # total number of transformer blocks. 
Default is `True` (following the TorchTitan implementation of Llama3). + if getattr(args, "depth_init", True): + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + self.args = args + self.inference = args["inference"] + + def set_inference_flag(self, flag: bool): + """ + Set the inference flag for the transformer layers. + """ + self.inference = flag + + def forward( + self, + x: torch.Tensor, + rotary_pos_emb: torch.Tensor, + mask: Optional[torch.Tensor], + inference_params: Optional[InferenceParams] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Custom forward to make sure we only pass relevant arguments to the + forward pass of the `TransformerLayer`. + + Args: + x (torch.Tensor): The input tensor. + mask (Optional[torch.Tensor]): The attention mask tensor. + inference_params (Optional[InferenceParams]): Inference parameters used for caching key-value pairs in the TE backend. + It is not applicable for the PyTorch backend and should be set to None in that case. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + torch.Tensor: The output tensor after applying the transformer block + """ + + inference_params = None if not self.inference else inference_params + output = super().forward( + x, + attention_mask=mask, + rotary_pos_emb=rotary_pos_emb.to(x.device), + inference_params=inference_params, + encoder_output=context, + enc_dec_attn_mask=context_mask, + ) + return output + + def init_weights(self): + """ + Initializes the weights of the transformer block. + """ + # Self Attention + attn_layer = self.self_attention.layernorm_qkv + for linear_weight in [attn_layer.query_weight, attn_layer.key_weight, attn_layer.value_weight]: + nn.init.trunc_normal_(linear_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.self_attention.proj.weight, mean=0.0, std=self.weight_init_std) + + # Cross Attention + if self.has_cross_attention: + nn.init.trunc_normal_(self.inter_attention.layernorm_query.query_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.inter_attention.key_value.key_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.inter_attention.key_value.value_weight, mean=0.0, std=0.02) + # zero-init the final output layer of cross-attention + if self.args["zero_init_cross_attn_proj"]: + nn.init.zeros_(self.inter_attention.proj.weight) + else: + nn.init.trunc_normal_(self.inter_attention.proj.weight, mean=0.0, std=self.weight_init_std) + + # RMS Normalization + for norm_weight in (self.layernorm_mlp.layer_norm_weight, self.self_attention.layernorm_qkv.layer_norm_weight): + torch.nn.init.ones_(norm_weight) + + # In the case of QK Normalization, we also reset the parameters of the QK normalization layers. + if self.args["use_qk_normalization"]: + for norm_weight in [self.self_attention.q_norm.weight, self.self_attention.k_norm.weight]: + torch.nn.init.ones_(norm_weight) + + # MLP + for linear_weight in (self.layernorm_mlp.fc1_weight, self.layernorm_mlp.fc2_weight): + nn.init.trunc_normal_(linear_weight, mean=0.0, std=self.weight_init_std) + # The fc1_weight is a fused weight of w1 and w2 in the MLP of the PyTorch backend, where w1 is initialized with + # a different std (0.02 by TorchTitan). So we re-initialize the w1 part of the fused weight below. 
+ split_point = self.layernorm_mlp.fc1_weight.shape[0] // 2 + nn.init.trunc_normal_(self.layernorm_mlp.fc1_weight[:split_point], mean=0.0, std=0.02) + + +class Transformer(nn.Module): + """ + The Transformer network consisting of transformer blocks. + """ + + def __init__(self, params, model_parallel=None, tokenizer_config=None, init_weights: bool = True): + """ + Initializes the Transformer module. + + Args: + params: The model parameters containing hyperparameters. + model_parallel: The model parallel configuration. + tokenizer_config: The model tokenizer configuration. + init_weights (bool): Whether to initialize the weights of the transformer following + TorchTitan's Llama3 initialization scheme. + """ + super().__init__() + # Check if self.params is an OmegaConf DictConfig instance + self.params = maybe_convert_to_namespace(params) + self.vocab_size = params["vocab_size"] + self.n_layers = params["n_layers"] + self.precision = getattr(torch, params["precision"]) + self.inference = params["inference"] + self.backend = params["backend"] + self.tokenizer_config = tokenizer_config + self.model_parallel = model_parallel + self.num_video_frames = params["num_video_frames"] + + self.token_emb_dropout = nn.Dropout(getattr(params, "embedding_dropout", 0.0)) + + tp_group = self._get_tp_group() + + # Sequence parallelism requires the first dimension to be the sequence dimension. When sequence parallelism + # is enabled, we transpose the first two dimensions of the input tensor, and specify the format as "sbhd", + # (sequence, batch, head, dim). Otherwise, the input format is "bshd" (batch, sequence, head, dim). + self.attn_input_format = "bshd" if not params["sequence_parallel"] else "sbhd" + + # Token embeddings + self.tok_embeddings = self._create_token_embeddings(self.model_parallel) + self.rope_config = self._create_rope_config() + + if self.backend == "pytorch": + self._initialize_pytorch_backend(model_parallel) + elif self.backend == "transformer_engine": + self._initialize_transformer_engine_backend(tp_group) + else: + raise ValueError(f"Unknown backend: {self.backend}") + + self.output = self._create_output_projection(model_parallel) + + # Action conditioning + self.use_action_condition = getattr(params, "use_action_condition", False) + if self.use_action_condition: + self.action_dim = getattr( + params, "action_dim", _ACTION_DIM + ) # e.g., [Δx, Δy, Δz, rx, ry, rz, gripper_open, zero_pad] + self.action_embedding_dim = self.params["action_embedding_dim"] # 1024 + self.action_embedding_mode = getattr(params, "action_embedding_mode", "mlp") # Default to mlp mode + self.group_causal_mask_mode = getattr( + params, "group_causal_mask_mode", None + ) # Default to None, 'causal' or 'group_diagonal' + self.action_embedding_layers = self._create_action_projection() + + if params["sequence_parallel"]: + if model_parallel is None: + setattr(params, "sequence_parallel", False) + log.critical("model_parallel is None. 
Disabling sequence parallelism.") + self.sequence_parallel_enabled = False + else: + assert self.backend == "transformer_engine", f"Invalid backend: {self.backend} for sequence parallelism" + assert ( + params["tensor_model_parallel_size"] > 1 + ), f"Invalid tensor_model_parallel_size: {params['tensor_model_parallel_size']}" + self.sequence_parallel_enabled = True + else: + self.sequence_parallel_enabled = False + + if init_weights: + self.init_weights() + + # Set default value for peft_last_n_layers and peft_every_n_layers + self.peft_last_n_layers = getattr(params, "peft_last_n_layers", 0) + self.peft_every_n_layers = getattr(params, "peft_every_n_layers", 0) + if self.peft_last_n_layers > 0 or self.peft_every_n_layers > 0: + self._setup_peft() + + # Freeze network parameters for finetuning w/ cross-attention + self.has_cross_attention = getattr(params, "insert_cross_attn", False) + if self.has_cross_attention: + self.ca_every_k_layers = getattr(params, "insert_cross_attn_every_k_layers", 1) + self.finetune_layers_with_cross_attn = getattr(params, "finetune_layers_with_cross_attn", False) + self.finetune_layers_without_cross_attn = getattr(params, "finetune_layers_without_cross_attn", False) + self._setup_cross_attn_ft() + + if self.params["apply_abs_pos_emb"]: + self.pos_emb_config = self._create_abs_pos_emb_config() + self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb() + if self.attn_input_format == "sbhd": + self.abs_pos_emb = self.abs_pos_emb.transpose(0, 1).contiguous() + self._broadcast_pos_emb(self.abs_pos_emb, tp_group) + + def _initialize_pytorch_backend(self, model_parallel): + self.layers = nn.ModuleList( + [ + TransformerBlock(layer_id, model_parallel, self.params).to(self.precision) + for layer_id in range(self.n_layers) + ] + ) + self.norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"]).to( + self.precision + ) + pytorch_rope_version = getattr(self.params, "pytorch_rope_version", "v2") + if pytorch_rope_version == "v1": + self.rope = RotaryPositionEmbeddingPytorch(**self.rope_config) + elif pytorch_rope_version == "v2": + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + self.rope = RotaryPositionEmbeddingPytorchV2( + seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config + ) + self._broadcast_pos_emb(self.rope.cos_cached, tp_group=self._get_tp_group()) + self._broadcast_pos_emb(self.rope.sin_cached, tp_group=self._get_tp_group()) + else: + raise ValueError(f"Unknown pytorch_rope_version: {pytorch_rope_version}") + + self.causal_mask = torch.tril( + torch.ones(self.params["max_seq_len"], self.params["max_seq_len"], dtype=torch.bool) + ).cuda() + + def _initialize_transformer_engine_backend(self, tp_group): + self.layers = self._create_transformer_layers(tp_group) + if self.params["sequence_parallel"]: + tp_group = parallel_state.get_tensor_model_parallel_group() + self.norm = AllReduceBWDRMSNormTE( + self.params["dim"], + process_group=tp_group, + eps=self.params["norm_eps"], + sequence_parallel=True, + ).to(self.precision) + else: + self.norm = RMSNormTE(self.params["dim"], eps=self.params["norm_eps"]).to(self.precision) + self.rope, self.rotary_pos_emb = self._initialize_rope() + self._broadcast_pos_emb(self.rotary_pos_emb, tp_group) + + def _create_rope_config(self) -> Dict: + shape_map = { + "3D": self.params["video_latent_shape"], + "2D": self.params["image_latent_shape"], + "1D": None, + } + latent_shape = 
shape_map.get(self.params["rope_dim"], None) + head_dim = self.params["head_dim"] + if head_dim is None: + head_dim = self.params["dim"] // self.params["n_heads"] + return { + "dim": head_dim, + "max_position_embeddings": self.params["max_seq_len"], + "original_max_position_embeddings": self.params["original_seq_len"], + "rope_theta": self.params["rope_theta"], + "apply_yarn": self.params["apply_yarn"], + "scale": self.params["yarn_scale"], + "beta_fast": self.params["yarn_beta_fast"], + "beta_slow": self.params["yarn_beta_slow"], + "rope_dim": self.params["rope_dim"], + "latent_shape": latent_shape, + "original_latent_shape": self.params["original_latent_shape"], + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_abs_pos_emb_config(self): + shape_map = { + "3D": self.params["video_latent_shape"], + "2D": self.params["image_latent_shape"], + "1D": None, + } + latent_shape = shape_map.get(self.params["rope_dim"], None) + return { + "dim": self.params["dim"], + "latent_shape": latent_shape, + "pad_to_multiple_of": self.params["pad_to_multiple_of"], + } + + def _create_token_embeddings(self, model_parallel=None, vocab_size: int = None): + """ + Create token embeddings. + + Args: + model_parallel: The model parallel configuration. + + Returns: + nn.Module: Token embeddings module. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + tp_size = self.params["tensor_model_parallel_size"] + if tp_size > 1: + # For inference in the PyTorch backend, we use PyTorch's allreduce (tracable) in the forward pass to enable torch.compile. + use_inference_allreduce = self.inference and self.params["backend"] == "pytorch" + emb = TrainingVocabParallelEmbedding( + vocab_size, + self.params["dim"], + init_method=lambda x: x, + config=model_parallel, + sequence_parallel=self.params["sequence_parallel"], + batch_first=not self.params["sequence_parallel"], + use_inference_allreduce=use_inference_allreduce, + ).to(self.precision) + return emb + else: + return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision) + + def _create_action_projection(self): + """ + Create the action projection layer. + + Returns: + nn.Module: Action projection layer. + """ + assert self.action_embedding_mode == "mlp", f"Invalid action embedding mode: {self.action_embedding_mode}" + + # This method is not working well. (option 1. default) exp102e + hidden_dim = self.action_embedding_dim // _MLP_HIDDEN_DIM_DIVISOR + action_embedding_layers = nn.Sequential( + nn.Linear(self.action_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, self.action_embedding_dim), + ) + + return action_embedding_layers + + def _get_tp_group( + self, + ): + """ + Get tensor parallel process group if applicable. + + Returns: + torch.distributed.ProcessGroup or None: Tensor parallel process group if tensor parallelism is enabled, else None. + """ + if self.params["tensor_model_parallel_size"] > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + log.info(f"Using tensor model parallel group: {tp_group}") + return tp_group + + return None + + def _create_transformer_layers(self, tp_group): + """ + Create the transformer layers. + + Args: + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + + Returns: + nn.ModuleList: List of transformer layers. 
+ """ + return nn.ModuleList( + [ + TransformerBlockTE( + layer_id, + self.params, + tp_group, + set_parallel_mode=self.params["set_parallel_mode"], + attn_input_format=self.attn_input_format, + ).to(self.precision) + for layer_id in range(self.params["n_layers"]) + ] + ) + + def _create_output_projection(self, model_parallel=None, vocab_size: int = None): + """ + Create the output projection layer. + + Args: + model_parallel: The model parallel configuration. + vocab_size (int): Vocabulary size (to override the default vocab size). + Returns: + LinearTE: Output projection layer. + """ + if vocab_size is None: + vocab_size = self.params["vocab_size"] + if self.params["tensor_model_parallel_size"] > 1: + if self.params["backend"] == "pytorch" and self.inference: + tp_size = self.params["tensor_model_parallel_size"] + layer = nn.Linear(self.params["dim"], vocab_size // tp_size, bias=False).to(self.precision) + return layer + else: + layer = ColumnParallelLinear( + self.params["dim"], + vocab_size, + bias=False, + gather_output=False, + init_method=lambda x: x, + config=model_parallel, + ).to(self.precision) + return layer + else: + # No Tensor Parallelism + if self.params["backend"] == "pytorch": + return nn.Linear(self.params["dim"], vocab_size, bias=False).to(self.precision) + elif self.params["backend"] == "transformer_engine": + return LinearTE(self.params["dim"], vocab_size, bias=False).to(self.precision) + else: + raise ValueError("Unknown backend: " + self.params["backend"]) + + def _initialize_rope( + self, + ): + """ + Initialize the rotary position embedding. + + Returns: + tuple: (RotaryPositionEmbeddingTE, torch.Tensor) The RoPE module and the rotary position embeddings. + """ + rope = RotaryPositionEmbeddingTE(**self.rope_config) + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + rotary_pos_emb = rope.forward(seq_len=self.params["max_seq_len"], training_type=training_type) + return rope, rotary_pos_emb + + def _initialize_abs_pos_emb(self): + pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config) + training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None + abs_pos_emb = pos_emb.forward(training_type=training_type) + return pos_emb, abs_pos_emb + + def _broadcast_pos_emb(self, pos_emb, tp_group): + """ + Broadcast the position embeddings across the tensor parallel group. + + Args: + pos_emb (torch.Tensor): Position embeddings to broadcast. + tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. + """ + if self.params["tensor_model_parallel_size"] > 1: + broadcast(pos_emb, min(get_process_group_ranks(tp_group)), group=tp_group) + + def _setup_peft(self): + """ + Set up Parameter Efficient Fine-Tuning (PEFT) by selectively freezing and unfreezing layers. + + This method configures the model for fine-tuning by: + 1. Freezing all parameters in the model. + 2. Unfreezing the embedding, normalization and output layers. + 3. Unfreezing the first and last (peft_last_n_layers - 1) transformer layers if peft_last_n_layers is set, + or unfreezing every n layers (flamingo style) if peft_every_n_layers is set. + """ + # Ensure only one of peft_last_n_layers and peft_every_n_layers is set + assert ( + self.peft_last_n_layers == 0 or self.peft_every_n_layers == 0 + ), "Only one of peft_last_n_layers and peft_every_n_layers can be set." 
+ + # First, freeze all parameters + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze embedding, normalization and output layers + for param in self.tok_embeddings.parameters(): + param.requires_grad = True + for param in self.norm.parameters(): + param.requires_grad = True + for param in self.output.parameters(): + param.requires_grad = True + + # PEFT last n layers + if self.peft_last_n_layers > 0: + # Ensure peft_last_n_layers is at least 2 + assert self.peft_last_n_layers >= 2, "peft_last_n_layers must be at least 2" + + # Unfreeze specific transformer layers + total_layers = len(self.layers) + for i, layer in enumerate(self.layers): + if i == 0 or i >= total_layers - self.peft_last_n_layers + 1: + # Unfreeze the first layer and the last (peft_last_n_layers - 1) layers + for param in layer.parameters(): + param.requires_grad = True + + log.info( + f"PEFT setup complete. Trainable components: embeddings, un-embedding, normalization layer, " + f"first transformer layer, last {self.peft_last_n_layers - 1} transformer layers." + ) + # PEFT every n layers (flamingo style, e.g. every 4 layers = layer 0,1,2,4,5,6,... frozen, layer 3,7,11,... is trainable) + else: + trainable_layers = [] + for i, layer in enumerate(self.layers, 1): + if i % self.peft_every_n_layers == 0: + for param in layer.parameters(): + param.requires_grad = True + trainable_layers.append(i - 1) + + log.info( + f"PEFT setup complete. Trainable components: embeddings, un-embedding, normalization layer, " + f"every {self.peft_every_n_layers} transformer layers (layer idx {trainable_layers}; total {len(trainable_layers)} layers)." + ) + + def _setup_cross_attn_ft(self): + """ + Set up Cross Attention Fine-Tuning by selectively freezing and unfreezing layers. + + This method configures the model for fine-tuning by: + 1. Freezing all parameters in the model. + 2. Unfreezing the embedding, normalization and output layers. + 3. Unfreezing all the added cross-attention layers. + 4. If `finetune_layers_with_cross_attn` is True, unfreeze the transformer layers for layers with cross attention. + 5. If `finetune_layers_without_cross_attn` is True, unfreeze the transformer layers for layers without cross attention. + 6. If 'use_action_condition' is True, unfreeze the action embedding layers. + """ + assert self.has_cross_attention, "Must insert cross-attention layers for finetuning." 
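+
+        # Illustrative outcome (hypothetical 4-layer model with insert_cross_attn_every_k_layers=2):
+        #   cross-attention modules in layers 0 and 2 are unfrozen, plus embeddings, norm, and output;
+        #   finetune_layers_with_cross_attn=True additionally makes layers 0 and 2 fully trainable;
+        #   finetune_layers_without_cross_attn=True additionally makes layers 1 and 3 fully trainable.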
+ finetune_layer_num = 0 + + # First, freeze all parameters + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze embedding, normalization and output layers + for param in self.tok_embeddings.parameters(): + param.requires_grad = True + for param in self.norm.parameters(): + param.requires_grad = True + for param in self.output.parameters(): + param.requires_grad = True + + # Unfreeze all the added cross-attention layers + total_layers = len(self.layers) + for i, layer in enumerate(self.layers): + if i % self.ca_every_k_layers == 0: + if self.params["backend"] == "pytorch": + for param in layer.cross_attention.parameters(): + param.requires_grad = True + elif self.params["backend"] == "transformer_engine": + for param in layer.inter_attention.parameters(): + param.requires_grad = True + else: + raise ValueError("Unknown backend: " + self.params["backend"]) + + # Unfreeze the transformer layers for layers with cross attention + if self.finetune_layers_with_cross_attn: + for i, layer in enumerate(self.layers): + if i % self.ca_every_k_layers == 0: + for param in layer.parameters(): + param.requires_grad = True + finetune_layer_num += 1 + + # Unfreeze the transformer layers for layers without cross attention + if self.finetune_layers_without_cross_attn: + for i, layer in enumerate(self.layers): + if i % self.ca_every_k_layers != 0: + for param in layer.parameters(): + param.requires_grad = True + finetune_layer_num += 1 + + # Unfreeze the action embedding layers + if self.use_action_condition: + for param in self.action_embedding_layers.parameters(): + param.requires_grad = True + + log.info( + f"cross attention finetune setup complete. Trainable components: cross-attention layer, " + f"fully trainable transformer layer number is {finetune_layer_num}." + ) + + def enable_context_parallel(self, cp_group: ProcessGroup): + """ + Enable context parallelism for the transformer model. + + This method sets up context parallelism by configuring the context parallel group + and updating each transformer layer to support context parallelism. + + Args: + cp_group (ProcessGroup): The process group for context parallelism. + + Notes: + - Updates the model's context parallel group and size. + - Configures each transformer layer for context parallelism. + - Enables context parallelism for the rotary position embedding if using the transformer engine backend. + """ + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + for layer_idx, layer in enumerate(self.layers): + if isinstance(layer, TransformerBlockTE): + layer.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + elif hasattr(layer, "module") and isinstance(layer.module, TransformerBlockTE): + layer.module.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + else: + log.warning(f"Layer {layer_idx} does not support context parallelism") + + def set_inference_flag(self, flag: bool): + """ + Set the inference flag for the transformer layers. 
+ """ + log.info(f"Setting inference flag to {flag}") + self.inference = flag + if self.inference: + self.eval() + if self.params["backend"] == "pytorch": + for layer in self.layers: + layer.attention.set_inference_flag(flag) + elif self.params["backend"] == "transformer_engine": + for layer in self.layers: + layer.set_inference_flag(flag) + + self._maybe_change_sequence_parallel_status(enable=False) + + def _maybe_change_sequence_parallel_status(self, enable: bool): + """ + Change the sequence parallel status of the transformer layers. + """ + if enable and not self.sequence_parallel_enabled: + for name, module in self.named_modules(): + if hasattr(module, "sequence_parallel"): + assert isinstance( + module.sequence_parallel, bool + ), f"Invalid type of {name}: {type(module.sequence_parallel)}" + setattr(module, "sequence_parallel", True) + self.sequence_parallel_enabled = True + elif not enable and self.sequence_parallel_enabled: + for name, module in self.named_modules(): + if hasattr(module, "sequence_parallel"): + assert isinstance( + module.sequence_parallel, bool + ), f"Invalid type of {name}: {type(module.sequence_parallel)}" + setattr(module, "sequence_parallel", False) + self.sequence_parallel_enabled = False + + def forward( + self, + tokens: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + token_embeddings: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + total_seq_len: Optional[int] = None, + return_hidden_states: bool = False, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the Transformer module. + + Args: + tokens (torch.Tensor, optional): The input tensor of token IDs. + input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache. PyTorch backend only. + inference_params (InferenceParams, optional): Parameters for inference. + token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + action (Optional[torch.Tensor]): The robot action tensor for conditioning. + total_seq_len (Optional[int]): The total sequence length (before applying context parallelism). + return_hidden_states (bool): Whether to return hidden states. + Returns: + The output tensor after applying the transformer layers. + """ + + # Turn on/off sequence parallelism based on the training status + self._maybe_change_sequence_parallel_status(enable=self.training and self.params["sequence_parallel"]) + + # Token embeddings + assert ( + tokens is None or token_embeddings is None + ), "Either tokens or token_embeddings should be provided, not both." 
+ + if token_embeddings is None: + seq_len = tokens.shape[1] + h = self.token_emb_dropout(self.tok_embeddings(tokens)) + else: + seq_len = token_embeddings.shape[1] + h = self.token_emb_dropout(token_embeddings) + + if mask is None: + # Create attention mask + mask = self._create_attention_mask(input_pos=input_pos) + + # Action embedding + if self.use_action_condition and action is not None: + assert self.action_embedding_mode == "mlp", f"Invalid action embedding mode: {self.action_embedding_mode}" + # change action type to bfloat16, of shape [batch_size, action_dim] + action = action.to(torch.bfloat16) + # action_emb shape: [batch_size, action_dim, action_embedding_dim] + action_emb = self.action_embedding_layers(action).unsqueeze(1).repeat(1, self.action_dim, 1) + + # Use action_emb as context + if self.params["concat_action_to_context"]: + context = torch.zeros( + (action_emb.shape[0], _T5_NUM_TOKENS, self.action_embedding_dim), device=h.device, dtype=h.dtype + ) + # context[:, -1, :] = action_emb[:, 0, :] # overwrite the last token with action_emb + context = torch.cat([context, action_emb[:, 0:1, :]], dim=1) + else: + context = action_emb # [batch_size, action_dim, action_embedding_dim] + + # Create context mask + if self.group_causal_mask_mode is not None: + num_temporal_groups = self.num_video_frames - 1 # number of latent frames + num_query_per_group = seq_len // num_temporal_groups # number of latent tokens per frame + num_key_per_group = self.action_dim // num_temporal_groups + context_mask = create_group_causal_attn_mask( + num_temporal_groups=num_temporal_groups, + num_query_per_group=num_query_per_group, + num_key_per_group=num_key_per_group, + mode=self.group_causal_mask_mode, + ) # [L (query), S (key)] + context_mask = context_mask.unsqueeze(0) # [1, L (query), S (key)] + context_mask = context_mask.repeat(context.shape[0], 1, 1) # [batch_size, L (query), S (key)] + context_mask = context_mask.to(context.device) + else: + context_mask = torch.ones( + (context.shape[0], context.shape[1]), device=context.device, dtype=torch.bool + ) # [batch_size, action_dim] + + # Prepare layer arguments + layer_kwargs = self._prepare_layer_kwargs( + total_seq_len=total_seq_len, + input_pos=input_pos, + mask=mask, + inference_params=inference_params, + context=context, + context_mask=context_mask, + ) + + # Apply transformer layers + for layer in self.layers: + if self.params["apply_abs_pos_emb"]: + h = self.apply_abs_pos_emb(h, input_pos=input_pos, total_seq_len=total_seq_len) + h = layer(h, **layer_kwargs) + + # Apply final layer normalization + h = self.norm(h) + if return_hidden_states: + return h + + # Output linear projection + output = self.output(h) + output = self.process_output(output) + return output + + def process_output(self, output: torch.Tensor) -> torch.Tensor: + """ + Adjusts the shape and layout of tensor based on tensor parallelism and attention input format. + + The function performs two operations: + 1. If the tensor model parallelism is enabled (`tensor_model_parallel_size > 1`), it gathers the tensor from + the tensor-parallel regions and reshapes it accordingly. + 2. If the attention input format is `"sbhd"` (Sequence, Batch, Hidden Dimension), it transposes the tensor + to the format `(Batch, Sequence, Hidden Dimension)` for further processing. + + Args: + output [torch.Tensor]: The tensor before modification. + + Returns: + output [torch.Tensor]: The tensor after modification. 
+ + """ + if self.params["tensor_model_parallel_size"] > 1: + if self.params["backend"] == "pytorch" and self.inference: + # Use PyTorch all gather + output = funcol.all_gather_tensor( + output, gather_dim=-1, group=parallel_state.get_tensor_model_parallel_group() + ) + else: + # [*, *, hidden_dim // tp_size] --> [*, *, hidden_dim] + output = gather_from_tensor_model_parallel_region(output) + if self.attn_input_format == "sbhd": + # [seq_len, batch_size, hidden_dim] --> [batch_size, seq_len, hidden_dim] + output = output.transpose(0, 1).contiguous() + return output + + def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """ + Creates an attention mask for the transformer layers. + + Args: + input_pos[torch.Tensor]: The position of input sequence (used for inference only). + + Returns: + Optional[torch.Tensor]: The attention mask, or None for causal mask. + """ + + if self.backend == "pytorch" and self.inference: + assert input_pos is not None, "input_pos must be provided for inference" + mask = self.causal_mask[input_pos] + return mask + else: + return None # None means causal mask + + def _prepare_layer_kwargs( + self, + total_seq_len: Optional[int], + input_pos: Optional[torch.Tensor], + mask: Optional[torch.Tensor], + inference_params: Optional[InferenceParams], + context: Optional[torch.Tensor], + context_mask: Optional[torch.Tensor], + ) -> Dict[str, Any]: + """ + Prepares the keyword arguments for transformer layers. + + Args: + total_seq_len (Optional[int]): The total sequence length (before applying context parallelism). + seq_len (Optional[int]): The length of the input sequence. + input_pos (Optional[torch.Tensor]): The position of the current sequence. + mask (Optional[torch.Tensor]): The attention mask. + inference_params (Optional[InferenceParams]): Parameters for inference. + context (Optional[torch.Tensor]): The context tensor added via cross-attn. + context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. + + Returns: + Dict[str, Any]: A dictionary of keyword arguments for the transformer layers. 
+ """ + if context is not None: + context = context.to(self.precision) + + if self.attn_input_format == "sbhd": + context = context.transpose(0, 1).contiguous() + if self.backend == "pytorch": + if isinstance(mask, torch.Tensor) and mask.ndim == 2: + mask = mask[None, None, :, :] + if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2: + context_mask = context_mask[None, None, :, :] + + layer_kwargs = { + "mask": mask, + "context": context, + "context_mask": context_mask, + } + + if self.backend == "pytorch": + layer_kwargs["input_pos"] = input_pos + layer_kwargs["rope"] = self.rope + elif self.backend == "transformer_engine": + rotary_pos_emb = self.rotary_pos_emb + try: + cp_size = parallel_state.get_context_parallel_world_size() + except (AssertionError, RuntimeError): + # Fallback if context parallel group isn't initialized + cp_size = 1 + log.warning("Context parallel group not initialized, falling back to size 1") + else: + cp_size = 1 + if cp_size > 1: + assert input_pos is None, "input_pos must be None for context parallelism" + rotary_pos_emb = rotary_pos_emb[:total_seq_len] + rotary_pos_emb = get_pos_emb_on_this_cp_rank(rotary_pos_emb, 0) + + layer_kwargs["rotary_pos_emb"] = rotary_pos_emb + layer_kwargs["inference_params"] = inference_params + + return layer_kwargs + + def apply_abs_pos_emb( + self, x: torch.Tensor, input_pos: int = None, total_seq_len: Optional[int] = None + ) -> torch.Tensor: + """ + Applies the absolute position embeddings to the input tensor. + """ + abs_pos_emb = self.abs_pos_emb + if total_seq_len is not None: + # Truncate the absolute position embeddings to the total sequence length + abs_pos_emb = ( + abs_pos_emb[:total_seq_len, :, :] + if self.attn_input_format == "sbhd" + else abs_pos_emb[:, :total_seq_len, :] + ) + cp_size = parallel_state.get_context_parallel_world_size() if self.training else 1 + if cp_size > 1: + assert input_pos is None + seq_dim = 0 if self.attn_input_format == "sbhd" else 1 + abs_pos_emb = get_pos_emb_on_this_cp_rank(abs_pos_emb, seq_dim=seq_dim) + if self.attn_input_format == "sbhd": + if self.sequence_parallel_enabled: + # Training + assert input_pos is None, "input_pos must be None when training with sequence parallelism" + abs_pos_emb = get_pos_emb_on_this_sptp_rank(abs_pos_emb, seq_dim=0) + else: + # Inference or Evaluation + abs_pos_emb = abs_pos_emb[input_pos, :, :] if input_pos is not None else abs_pos_emb + else: + abs_pos_emb = abs_pos_emb[:, input_pos, :] if input_pos is not None else abs_pos_emb + return x + abs_pos_emb + + @torch.no_grad() + def expand_vocab( + self, new_vocab_size: int, init_method: str = "gaussian", multiple_of=64, expand_output_layer=True + ): + """ + Expands the vocabulary of the model to the new size. + + Args: + new_vocab_size (int): The new vocabulary size. + init_method (str): The initialization method for new embeddings. + Can be "zero" or "gaussian". Default is "gaussian". + multiple_of (int): The new vocabulary size must be a multiple of this value. Defaults to 64 to fully + leverage the power of NVIDIA TensorCore (source 1: https://x.com/karpathy/status/1621578354024677377, + source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc) + expand_output_layer (bool): Whether to also expand the output layer. Defaults to True. 
+ + Returns: + None + """ + + tp_size = self.params["tensor_model_parallel_size"] + if new_vocab_size <= self.vocab_size: + raise ValueError( + f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})" + ) + if new_vocab_size % multiple_of != 0: + log.critical(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.") + new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of + log.critical(f"Rounded vocabulary size to {new_vocab_size}.") + # Resize token embeddings + old_embeddings = self.tok_embeddings + old_embeddings_requires_grad = old_embeddings.weight.requires_grad + tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype} + self.tok_embeddings = self._create_token_embeddings( + model_parallel=self.model_parallel, vocab_size=new_vocab_size + ).to(**tensor_kwargs) + # Initialize new embeddings + if init_method not in ["zero", "gaussian"]: + raise ValueError(f"Unknown initialization method: {init_method}") + # The default initialization of nn.Embedding is Gaussian, so we don't need to do anything + # if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings. + if init_method == "zero": + self.tok_embeddings.weight.data[self.vocab_size // tp_size :].zero_() + + # Copy old embeddings + log.info( + f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}" + ) + self.tok_embeddings.weight.data[: self.vocab_size // tp_size] = old_embeddings.weight.data + self.tok_embeddings.weight.requires_grad = old_embeddings_requires_grad + # Resize output layer + old_output = self.output + old_output_requires_grad = old_output.weight.requires_grad + self.output = self._create_output_projection( + self.model_parallel, vocab_size=new_vocab_size if expand_output_layer else None + ) + + # Initialize new output weights + if init_method == "zero": + self.output.weight.data[self.vocab_size // tp_size :].zero_() + elif init_method == "gaussian": + # Follows the parameter initialization in TorchTitan: + # https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py + final_out_std = self.params["dim"] ** -0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + # Copy old output weights + self.output.weight.data[: self.vocab_size // tp_size] = old_output.weight.data + self.output.weight.requires_grad = old_output_requires_grad + + # Update vocab size + self.vocab_size = new_vocab_size + log.critical(f"Expanded vocabulary size to {new_vocab_size}") + + def init_weights(self): + """ + [Note: On ``init_weights`` vs. ``reset_parameters`` (copied from github.com/pytorch/torchtitan)] + Modules may define ``reset_parameters`` to initialize parameter values. ``reset_parameters`` is meant to only + initialize directly owned parameters/buffers, not those of their child modules, and it can be used to give the + initial values for these tensors. Separately, users may want custom initialization for their modules, different + from that in ``reset_parameters``. For this, we define ``init_weights``. We only call it in the constructor of + this ``Transformer`` root module to avoid reinitializing tensors. 
+ """ + + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers: + layer.init_weights() + if self.backend == "pytorch": + self.norm.reset_parameters() + elif self.backend == "transformer_engine": + nn.init.ones_(self.norm.weight) + else: + raise ValueError(f"Unknown backend: {self.backend}") + final_out_std = self.params["dim"] ** -0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + if self.use_action_condition: + for layer in self.action_embedding_layers: + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform_(layer.weight) + nn.init.zeros_(layer.bias) + + def state_dict(self, *args, **kwargs): + """ + Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8). + """ + state_dict = super().state_dict(*args, **kwargs) + return process_state_dict(state_dict) + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + """ + Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by + TransformerEngine for FP8). + """ + state_dict = process_state_dict(state_dict) + missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign) + if strict: + actual_missing_keys = [] + for key in missing_keys: + if not any(substring in key for substring in substrings_to_ignore): + actual_missing_keys.append(key) + if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0: + raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}") + missing_keys = actual_missing_keys + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def on_after_backward(self, *args, **kwargs): + """ + All-reduce layernorm grads for tensor/sequence parallelism. + Reference implementation: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/finalize_model_grads.py + """ + allreduce_layernorm_grads( + [self], + tensor_model_parallel_size=self.params["tensor_model_parallel_size"], + sequence_parallel=self.params["sequence_parallel"], + ) + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """Hook before zero_grad() is called. + + Args: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + iteration (int): Current iteration number. + """ + if self.params["sync_1d_parameters"]: + if self.params["tensor_model_parallel_size"] > 1: + sync_1d_parameters(self, process_group=parallel_state.get_tensor_model_parallel_group()) + if self.params["context_parallel_size"] > 1: + sync_1d_parameters(self, process_group=parallel_state.get_context_parallel_group()) diff --git a/cosmos_predict1/autoregressive/utils/__init__.py b/cosmos_predict1/autoregressive/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/autoregressive/utils/checkpoint.py b/cosmos_predict1/autoregressive/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..fb49e9e03173caccc8473c513d70124dc371d38e --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/checkpoint.py @@ -0,0 +1,594 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch + +from cosmos_predict1.utils import log + +# Substrings to ignore when processing state dicts +substrings_to_ignore = [ + "_extra_state", # Extra states (BytesIO type) added by TransformerEngine for FP8 handling +] + + +def identify_checkpoint_backend(state_dict: dict[str, torch.Tensor]) -> str: + """ + Identify the backend of the checkpoint (PyTorch or TransformerEngine) + + Args: + state_dict (dict[str, torch.Tensor]): The state dict to check + + Returns: + str: The backend of the checkpoint + """ + for key in state_dict.keys(): + if "self_attention.layernorm_qkv.query_weight" in key: + return "transformer_engine" + elif "attention.wq.weight" in key: + return "pytorch" + raise ValueError("Could not identify the backend of the checkpoint") + + +def get_partial_state_dict( + state_dict: dict[str, torch.Tensor], + prefix: str, +) -> dict[str, torch.Tensor]: + """ + Get a partial state dict with keys starting with the given prefix + """ + return {k: v for k, v in state_dict.items() if k.startswith(prefix)} + + +def process_state_dict( + state_dict: dict[str, torch.Tensor], + device: str = None, + dtype: torch.dtype = None, + prefix_to_remove: Optional[str] = None, +) -> dict[str, torch.Tensor]: + """ + - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8) + - Move tensors to specified device and dtype if provided + + Args: + state_dict (dict[str, torch.Tensor]): The state dict to process + device (str, optional): The device to move tensors to. Defaults to None. + dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None. + prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None. 
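+
+        Illustrative example (hypothetical keys and tensors):
+            >>> sd = {"model.layers.0.attention.wq.weight": w, "model.layers.0.attention._extra_state": b}
+            >>> process_state_dict(sd, dtype=torch.bfloat16, prefix_to_remove="model.")
+            {"layers.0.attention.wq.weight": w.to(torch.bfloat16)}   # "_extra_state" entries are dropped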
+ + Returns: + dict[str, torch.Tensor]: The processed state dict + """ + new_state_dict = {} + tensor_kwargs = {} + if device is not None: + tensor_kwargs["device"] = device + if dtype is not None: + tensor_kwargs["dtype"] = dtype + + for key, value in state_dict.items(): + # Check if any of the substrings to ignore are in the key + skip = False + for substr in substrings_to_ignore: + if substr in key: + skip = True + break + if skip: + continue + if len(tensor_kwargs) > 0: + value = value.to(**tensor_kwargs) + if prefix_to_remove is not None and key.startswith(prefix_to_remove): + key = key[len(prefix_to_remove) :] + new_state_dict[key] = value + return new_state_dict + + +def obtain_tensor_parallel_state_dict( + whole_model_state_dict: dict[str, torch.Tensor], + tensor_parallel_size: int, + tensor_parallel_rank: int, + model_config, + target_backend: str = None, +) -> dict[str, torch.Tensor]: + """ + Obtain the tensor parallel state dict shard for the current rank. + + Args: + whole_model_state_dict (dict[str, torch.Tensor]): The complete model state dict. + tensor_parallel_size (int): The number of tensor parallel devices. + tensor_parallel_rank (int): The rank of the current tensor parallel device. + model_config: The model configuration. + target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used. + + Returns: + dict[str, torch.Tensor]: The updated state dict shard for the current tensor parallel rank. + """ + new_state_dict_shard = {} + whole_model_state_dict = process_state_dict(whole_model_state_dict) + source_backend = identify_checkpoint_backend(whole_model_state_dict) + if source_backend != "pytorch": + # Convert the checkpoint to PyTorch backend for checkpoint sharding + whole_model_state_dict = maybe_convert_checkpoint_to_backend( + whole_model_state_dict, target_backend="pytorch", model_config=model_config, source_backend=source_backend + ) + + n_heads = model_config["n_heads"] + n_kv_heads = model_config["n_kv_heads"] + dim = model_config["dim"] + context_dim = model_config["context_dim"] + for key, value in whole_model_state_dict.items(): + prefix = "model." if key.startswith("model.") else "" # LLM's model prefix + prefix = "transformer." if key.startswith("transformer.") else prefix # VIT's model prefix + key = key.replace(prefix, "") + if key.startswith("layers."): + layer_index = int(key.split("layers.")[1].split(".")[0]) + if layer_index >= model_config["n_layers"]: + log.warning( + f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer." 
+ ) + continue + if ".attention.wq.weight" in key or "cross_attention.wq.weight" in key: + value = torch.chunk(value.view(n_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank] + value = value.reshape(-1, dim) + elif ".attention.wk.weight" in key or ".attention.wv.weight" in key: + value = torch.chunk(value.view(n_kv_heads, -1, dim), tensor_parallel_size, dim=0)[tensor_parallel_rank] + value = value.reshape(-1, dim) + elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key: + assert context_dim is not None + value = torch.chunk(value.view(n_kv_heads, -1, context_dim), tensor_parallel_size, dim=0)[ + tensor_parallel_rank + ] + value = value.reshape(-1, context_dim) + elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key: + value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank] + elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key: + value = torch.chunk(value, tensor_parallel_size, dim=1)[tensor_parallel_rank] + else: + # Handle non-layer weights + if key == "tok_embeddings.weight" or key == "output.weight" or "medusa_head" in key: + value = torch.chunk(value, tensor_parallel_size, dim=0)[tensor_parallel_rank] + new_state_dict_shard[prefix + key] = value + + if target_backend is None: + target_backend = source_backend + + new_state_dict_shard = maybe_convert_checkpoint_to_backend( + new_state_dict_shard, + target_backend=target_backend, + model_config=model_config, + is_tensor_parallel_shard=True, + tensor_parallel_size=tensor_parallel_size, + ) + + return new_state_dict_shard + + +def merge_tensor_parallel_state_dicts( + state_dict_shards: list[dict[str, torch.Tensor]], + model_config, + target_backend: str = None, +) -> dict[str, torch.Tensor]: + """ + Merge tensor parallel state dict shards into a whole model state dict. + + Args: + state_dict_shards (List[Dict[str, torch.Tensor]]): The list of state dict shards to merge. + model_config: The model configuration. + target_backend (str, optional): The target backend format ('pytorch', 'transformer_engine', or 'huggingface'). If not specified, the source backend will be used. + + Returns: + Dict[str, torch.Tensor]: The merged state dict. + """ + state_dict_shards = [process_state_dict(shard, device="cpu") for shard in state_dict_shards] + tensor_parallel_size = len(state_dict_shards) + source_backend = identify_checkpoint_backend(state_dict_shards[0]) + if source_backend != "pytorch": + log.critical(f"Converting from {source_backend} to PyTorch backend for tensor parallel checkpoint merging.") + state_dict_shards = [ + maybe_convert_checkpoint_to_backend( + shard, + target_backend="pytorch", + model_config=model_config, + source_backend=source_backend, + is_tensor_parallel_shard=True, + tensor_parallel_size=tensor_parallel_size, + ) + for shard in state_dict_shards + ] + + n_heads = model_config["n_heads"] + n_kv_heads = model_config["n_kv_heads"] + n_local_heads = n_heads // tensor_parallel_size + n_local_kv_heads = n_kv_heads // tensor_parallel_size + dim = model_config["dim"] + context_dim = model_config["context_dim"] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + query_dim = head_dim * n_heads + key_value_dim = head_dim * n_kv_heads + merged_state_dict = {} + + for key in state_dict_shards[0].keys(): + prefix = "model." 
if key.startswith("model.") else "" + key_without_prefix = key[len(prefix) :] + if key_without_prefix.startswith("layers."): + layer_index = int(key_without_prefix.split("layers.")[1].split(".")[0]) + if layer_index >= model_config["n_layers"]: + log.warning( + f"Layer index {layer_index} is greater than the number of layers {model_config['n_layers']}. Skipping this layer." + ) + continue + if key_without_prefix == "tok_embeddings.weight" or key_without_prefix == "output.weight": + merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0) + elif ".attention.wq.weight" in key or "cross_attention.wq.weight" in key: + chunks = [shard[key].view(n_local_heads, head_dim, dim) for shard in state_dict_shards] + merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(query_dim, dim) + elif ".attention.wk.weight" in key or ".attention.wv.weight" in key: + chunks = [shard[key].view(n_local_kv_heads, head_dim, dim) for shard in state_dict_shards] + merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, dim) + elif "cross_attention.wk.weight" in key or "cross_attention.wv.weight" in key: + chunks = [shard[key].view(n_local_kv_heads, head_dim, context_dim) for shard in state_dict_shards] + merged_state_dict[key] = torch.cat(chunks, dim=0).reshape(key_value_dim, context_dim) + elif "feed_forward.w1.weight" in key or "feed_forward.w3.weight" in key or "medusa_head" in key: + merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=0) + elif "feed_forward.w2.weight" in key or ".attention.wo.weight" in key or "cross_attention.wo.weight" in key: + merged_state_dict[key] = torch.cat([shard[key] for shard in state_dict_shards], dim=1) + else: + avg_tensor = torch.stack([shard[key] for shard in state_dict_shards]).mean(dim=0) + # make sure shard-0 is close to the average tensor + assert torch.allclose(state_dict_shards[0][key], avg_tensor, atol=5e-2, rtol=0.1), ( + f"Shard-0 tensor {key} is not close to the average tensor. " + f"Max diff: {torch.max(torch.abs(state_dict_shards[0][key] - avg_tensor))}, " + ) + merged_state_dict[key] = avg_tensor + assert "norm" in key, f"Assumed the key {key} is a norm layer, which should be the same across shards." + + if target_backend is None: + target_backend = source_backend + return maybe_convert_checkpoint_to_backend( + merged_state_dict, target_backend=target_backend, model_config=model_config + ) + + +def te_to_pytorch_state_dict( + te_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1 +) -> Dict[str, torch.Tensor]: + """ + Convert a TransformerEngine state dict to PyTorch state dict + + Args: + te_state_dict (Mapping[str, torch.Tensor]): The TransformerEngine state dict + model_config: The model configuration + tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard). 
+ + Returns: + Mapping[str, torch.Tensor]: The PyTorch state dict + """ + + if hasattr(model_config, "asdict"): + model_config = model_config.asdict() + + pytorch_state_dict = {} + replacement_rules = [ + # Self-attention modules + (".self_attention.layernorm_qkv.layer_norm_weight", ".attention_norm.weight"), + (".self_attention.layernorm_qkv.query_weight", ".attention.wq.weight"), + (".self_attention.layernorm_qkv.key_weight", ".attention.wk.weight"), + (".self_attention.layernorm_qkv.value_weight", ".attention.wv.weight"), + (".self_attention.proj.weight", ".attention.wo.weight"), + (".self_attention.", ".attention."), # Handle the rest modules such as q_norm and k_norm + # MLP modules + (".layernorm_mlp.layer_norm_weight", ".ffn_norm.weight"), + (".layernorm_mlp.fc2_weight", ".feed_forward.w2.weight"), + # Cross-attention modules + (".inter_attention.layernorm_query.query_weight", ".cross_attention.wq.weight"), + (".inter_attention.key_value.key_weight", ".cross_attention.wk.weight"), + (".inter_attention.key_value.value_weight", ".cross_attention.wv.weight"), + (".inter_attention.proj.weight", ".cross_attention.wo.weight"), + (".inter_attention.layernorm_query.layer_norm_weight", ".cross_attention_norm.weight"), + (".inter_attention.", ".cross_attention."), # Handle the rest modules such as q_norm and k_norm + ] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + for old_key, value in te_state_dict.items(): + new_key = old_key + for old_substr, new_substr in replacement_rules: + if old_substr in new_key: + new_key = new_key.replace(old_substr, new_substr) + break + + # Handle the fused w1 and w3 case + if "layernorm_mlp.fc1_weight" in old_key: + fused_weight = value + split_point = fused_weight.shape[0] // 2 + w1_weight = fused_weight[:split_point] + w3_weight = fused_weight[split_point:] + + w1_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w1.weight") + w3_key = new_key.replace("layernorm_mlp.fc1_weight", "feed_forward.w3.weight") + + pytorch_state_dict[w1_key] = w1_weight + pytorch_state_dict[w3_key] = w3_weight + else: + if model_config["pytorch_rope_version"] == "v1": + # If the model use qk normalization, we will use the same PyTorch RoPE operations as the TE version. + # Thus, we do not need to permute the weights. + if "query_weight" in old_key: + value = inverse_permute_weight( + value, + n_heads=model_config["n_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_heads"] // tensor_parallel_size, + dim2=model_config["dim"], + ) + elif "key_weight" in old_key: + value = inverse_permute_weight( + value, + n_heads=model_config["n_kv_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size, + dim2=model_config["context_dim"] if "inter_attention" in old_key else model_config["dim"], + ) + pytorch_state_dict[new_key] = value + + return pytorch_state_dict + + +def pytorch_to_te_state_dict( + pytorch_state_dict: Dict[str, torch.Tensor], model_config, tensor_parallel_size: int = 1 +) -> Dict[str, torch.Tensor]: + """ + Convert a PyTorch state dict to TransformerEngine state dict + + Args: + pytorch_state_dict (Mapping[str, torch.Tensor]): The PyTorch state dict + model_config: The model configuration + tensor_parallel_size (int): The tensor parallel size. Defaults to 1 (i.e., not a tensor parallel shard). 
+ + Returns: + Mapping[str, torch.Tensor]: The TransformerEngine + """ + + if hasattr(model_config, "asdict"): + model_config = model_config.asdict() + + te_state_dict = {} + + replacement_rules = [ + # Self-attention modules + (".attention_norm.weight", ".self_attention.layernorm_qkv.layer_norm_weight"), + (".attention.wq.weight", ".self_attention.layernorm_qkv.query_weight"), + (".attention.wk.weight", ".self_attention.layernorm_qkv.key_weight"), + (".attention.wv.weight", ".self_attention.layernorm_qkv.value_weight"), + (".attention.wo.weight", ".self_attention.proj.weight"), + (".attention.", ".self_attention."), + # MLP modules + (".ffn_norm.weight", ".layernorm_mlp.layer_norm_weight"), + (".feed_forward.w2.weight", ".layernorm_mlp.fc2_weight"), + # Cross-attention modules + (".cross_attention_norm.weight", ".inter_attention.layernorm_query.layer_norm_weight"), + (".cross_attention.wq.weight", ".inter_attention.layernorm_query.query_weight"), + (".cross_attention.wk.weight", ".inter_attention.key_value.key_weight"), + (".cross_attention.wv.weight", ".inter_attention.key_value.value_weight"), + (".cross_attention.wo.weight", ".inter_attention.proj.weight"), + (".cross_attention.", ".inter_attention."), + ] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + for old_key, value in pytorch_state_dict.items(): + new_key = old_key + for new_substr, old_substr in replacement_rules: + if new_substr in new_key: + new_key = new_key.replace(new_substr, old_substr) + break + + # Handle the split w1 and w3 case + if "feed_forward.w1.weight" in old_key: + w1_weight = value + w3_key = old_key.replace("feed_forward.w1.weight", "feed_forward.w3.weight") + if w3_key in pytorch_state_dict: + w3_weight = pytorch_state_dict[w3_key] + fused_weight = torch.cat([w1_weight, w3_weight], dim=0) + new_key = new_key.replace("feed_forward.w1.weight", "layernorm_mlp.fc1_weight") + te_state_dict[new_key] = fused_weight + else: + te_state_dict[new_key] = value + elif "feed_forward.w3.weight" in old_key: + # Skip w3 weights as they're handled with w1 + continue + else: + if model_config["pytorch_rope_version"] == "v1": + # If the model use qk normalization, we will use the same PyTorch RoPE operations as the TE version. + # Thus, we do not need to permute the weights. 
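                # Illustrative note (sketch only, not part of the conversion logic): numerically,
                # permute_weight() regroups each head's rows from interleaved pairs into two
                # contiguous halves, and inverse_permute_weight() undoes it. For a single 4-row
                # head with dim2=1:
                #     w = torch.arange(4.0).unsqueeze(1)             # rows [0, 1, 2, 3]
                #     permute_weight(w, n_heads=1, dim1=4, dim2=1)   # rows become [0, 2, 1, 3]
                #     inverse_permute_weight(
                #         permute_weight(w, n_heads=1, dim1=4, dim2=1), n_heads=1, dim1=4, dim2=1
                #     )                                              # rows restored to [0, 1, 2, 3]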
+ if "attention.wq" in old_key: + value = permute_weight( + value, + n_heads=model_config["n_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_heads"] // tensor_parallel_size, + dim2=model_config["dim"], + ) + elif "attention.wk" in old_key: + value = permute_weight( + value, + n_heads=model_config["n_kv_heads"] // tensor_parallel_size, + dim1=head_dim * model_config["n_kv_heads"] // tensor_parallel_size, + dim2=model_config["context_dim"] if "cross_attention" in old_key else model_config["dim"], + ) + te_state_dict[new_key] = value + + return te_state_dict + + +def permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor: + """ + Helper function for converting checkpoints from PyTorch to TransformerEngine + Permute the query weight or key weight of each attention layer + Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py + + Args: + w (torch.Tensor): The weight tensor to permute + n_heads (int): The number of attention heads + dim1 (int): The first dimension of the weight tensor + dim2 (int): The second dimension of the weight tensor + + Returns: + torch.Tensor: The permuted weight tensor + """ + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def inverse_permute_weight(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor: + """ + Helper function for converting checkpoints from TransformerEngine to PyTorch + Permute the query weight or key weight of each attention layer + + Args: + w (torch.Tensor): The weight tensor to permute + n_heads (int): The number of attention heads + dim1 (int): The first dimension of the weight tensor + dim2 (int): The second dimension of the weight tensor + + Returns: + torch.Tensor: The permuted weight tensor + """ + return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def pytorch_to_hf_state_dict( + state_dict: Dict[str, torch.Tensor], model_config: Dict[str, Any], tensor_parallel_size: int = 1 +) -> Dict[str, torch.Tensor]: + """ + Convert a PyTorch state dict to HuggingFace format for LLM models. + + Args: + state_dict (Mapping[str, torch.Tensor]): + The original PyTorch model's state dictionary. + This is a mapping where keys are layer names and values are the corresponding PyTorch tensors + containing the model weights. + + model_config (Mapping[str, Any]): + The configuration of the model. This dictionary contains parameters such as: + - n_layers: (int) The number of transformer layers. + - n_heads: (int) The number of attention heads. + - dim: (int) The hidden size of the model. + - n_kv_heads: (int, optional) The number of key-value heads for multi-query attention. + + Returns: + Mapping[str, torch.Tensor]: + The converted HuggingFace state dictionary. This dictionary maps HuggingFace transformer-compatible + layer names to the corresponding model weights. + """ + not_supported_key_substrings = ["cross_attention", "q_norm", "k_norm"] + for key in state_dict.keys(): + if any(substr in key for substr in not_supported_key_substrings): + raise ValueError(f"Key {key} is not supported in HuggingFace format.") + assert tensor_parallel_size == 1, "Tensor parallel size > 1 is not supported for HuggingFace model export." 
+ + hf_state_dict = {} + + n_layers = model_config["n_layers"] + n_heads = model_config["n_heads"] + dim = model_config["dim"] + head_dim = model_config["head_dim"] + if head_dim is None: + head_dim = model_config["dim"] // model_config["n_heads"] + + num_key_value_heads = model_config.get("n_kv_heads", n_heads) + key_value_dim = head_dim * num_key_value_heads + + for layer_i in range(n_layers): + pt_prefix = f"layers.{layer_i}." + hf_prefix = f"model.layers.{layer_i}." + + wq = state_dict[f"{pt_prefix}attention.wq.weight"] + wk = state_dict[f"{pt_prefix}attention.wk.weight"] + if model_config["pytorch_rope_version"] == "v1": + wq = permute_weight( + wq, + n_heads=n_heads, + dim1=dim, + dim2=dim, + ) + wk = permute_weight( + wk, + n_heads=num_key_value_heads, + dim1=key_value_dim, + dim2=dim, + ) + hf_state_dict[f"{hf_prefix}self_attn.q_proj.weight"] = wq + hf_state_dict[f"{hf_prefix}self_attn.k_proj.weight"] = wk + hf_state_dict[f"{hf_prefix}self_attn.v_proj.weight"] = state_dict[f"{pt_prefix}attention.wv.weight"] + hf_state_dict[f"{hf_prefix}self_attn.o_proj.weight"] = state_dict[f"{pt_prefix}attention.wo.weight"] + hf_state_dict[f"{hf_prefix}mlp.gate_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w1.weight"] + hf_state_dict[f"{hf_prefix}mlp.down_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w2.weight"] + hf_state_dict[f"{hf_prefix}mlp.up_proj.weight"] = state_dict[f"{pt_prefix}feed_forward.w3.weight"] + hf_state_dict[f"{hf_prefix}input_layernorm.weight"] = state_dict[f"{pt_prefix}attention_norm.weight"] + hf_state_dict[f"{hf_prefix}post_attention_layernorm.weight"] = state_dict[f"{pt_prefix}ffn_norm.weight"] + + # Add non-layer weights + hf_state_dict["model.embed_tokens.weight"] = state_dict["tok_embeddings.weight"] + hf_state_dict["model.norm.weight"] = state_dict["norm.weight"] + hf_state_dict["lm_head.weight"] = state_dict["output.weight"] + + return hf_state_dict + + +def maybe_convert_checkpoint_to_backend( + state_dict: Dict[str, torch.Tensor], + target_backend: str, + model_config, + source_backend: str = None, + is_tensor_parallel_shard: bool = False, + tensor_parallel_size: int = None, +): + """ + Identify the backend of the checkpoint and convert to the target backend if necessary. + + This function checks the current backend of the state_dict and converts it to the target backend + if they don't match. It supports conversions between PyTorch, TransformerEngine, and HuggingFace backends. + + Args: + state_dict (Dict[str, torch.Tensor]): The model state dictionary to convert. + target_backend (str): The desired backend format ('pytorch', 'transformer_engine', or 'huggingface'). + model_config: Configuration of the model, used in conversion process. + source_backend (str, optional): The current backend of the state_dict. If not specified, the function will identify the backend. + is_tensor_parallel_shard (bool, optional): Whether the state_dict is a tensor parallel shard. Defaults to False. + tensor_parallel_size (int, optional): The tensor parallel size. If not specified, the model_config will be modified. + Returns: + Dict[str, torch.Tensor]: The converted state dictionary in the target backend format. + + Raises: + ValueError: If the conversion between the identified backend and target backend is not supported. 
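    Example (illustrative sketch; assumes a hypothetical checkpoint file ``model_te.pt`` and a
    populated ``model_config`` with the usual keys such as ``n_heads``, ``n_kv_heads``, ``dim``
    and ``pytorch_rope_version``):

        te_ckpt = torch.load("model_te.pt", map_location="cpu")
        pt_ckpt = maybe_convert_checkpoint_to_backend(
            te_ckpt, target_backend="pytorch", model_config=model_config
        )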
+ """ + # Identify the current backend of the checkpoint + state_dict = process_state_dict(state_dict) # Remove unnecessary keys + if source_backend is None: + source_backend = identify_checkpoint_backend(state_dict) + if source_backend == target_backend: + return state_dict + else: + if tensor_parallel_size is None: + tensor_parallel_size = model_config["tensor_parallel_size"] if is_tensor_parallel_shard else 1 + # Convert to target backend + if source_backend == "pytorch" and target_backend == "transformer_engine": + return pytorch_to_te_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size) + elif source_backend == "transformer_engine" and target_backend == "pytorch": + return te_to_pytorch_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size) + elif source_backend == "pytorch" and target_backend == "huggingface": + return pytorch_to_hf_state_dict(state_dict, model_config, tensor_parallel_size=tensor_parallel_size) + else: + raise ValueError(f"Conversion from {source_backend} to {target_backend} is not supported.") diff --git a/cosmos_predict1/autoregressive/utils/inference.py b/cosmos_predict1/autoregressive/utils/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..bca159c50db62a6938b7dbe7041423958f014c3c --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/inference.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import math +import os +from pathlib import Path +from typing import List + +import numpy as np +import torch +import torchvision +from PIL import Image + +from cosmos_predict1.autoregressive.configs.inference.inference_config import SamplingConfig +from cosmos_predict1.utils import log + +_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", "webp"] +_VIDEO_EXTENSIONS = [".mp4"] +_SUPPORTED_CONTEXT_LEN = [1, 9] # Input frames +NUM_TOTAL_FRAMES = 33 + + +def add_common_arguments(parser): + """Add common command line arguments. 
+ + Args: + parser (ArgumentParser): Argument parser to add arguments to + """ + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--video_save_name", + type=str, + default="output", + help="Output filename for generating a single video", + ) + parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos") + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input path for input image or video", + ) + parser.add_argument( + "--batch_input_path", + type=str, + help="Input folder containing all input images or videos", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=9, + help="Number of input frames for world generation", + choices=_SUPPORTED_CONTEXT_LEN, + ) + parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling") + parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.") + parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder") + parser.add_argument( + "--offload_guardrail_models", + action="store_true", + help="Offload guardrail models after inference", + ) + parser.add_argument( + "--offload_diffusion_decoder", + action="store_true", + help="Offload diffusion decoder after inference", + ) + parser.add_argument( + "--offload_ar_model", + action="store_true", + help="Offload AR model after inference", + ) + parser.add_argument( + "--offload_tokenizer", + action="store_true", + help="Offload discrete tokenizer model after inference", + ) + parser.add_argument( + "--disable_guardrail", + action="store_true", + help="Disable guardrail models", + ) + + +def validate_args(args: argparse.Namespace, inference_type: str): + """Validate command line arguments for base and video2world generation.""" + assert inference_type in [ + "base", + "video2world", + ], "Invalid inference_type, must be 'base' or 'video2world'" + if args.input_type in ["image", "text_and_image"] and args.num_input_frames != 1: + args.num_input_frames = 1 + log.info(f"Set num_input_frames to 1 for {args.input_type} input") + + if args.num_input_frames == 1: + if "4B" in args.ar_model_dir: + log.warning( + "The failure rate for 4B model with image input is ~15%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details." + ) + elif "5B" in args.ar_model_dir: + log.warning( + "The failure rate for 5B model with image input is ~7%. 12B / 13B model have a smaller failure rate. Please be cautious and refer to README.md for more details." + ) + + # Validate prompt/image/video args for single or batch generation + assert ( + args.input_image_or_video_path or args.batch_input_path + ), "--input_image_or_video_path or --batch_input_path must be provided." + if inference_type == "video2world" and (not args.batch_input_path): + assert args.prompt, "--prompt is required for single video generation." 
+ args.data_resolution = [640, 1024] + + # Create output folder + Path(args.video_save_folder).mkdir(parents=True, exist_ok=True) + + sampling_config = SamplingConfig( + echo=True, + temperature=args.temperature, + top_p=args.top_p, + compile_sampling=True, + ) + return sampling_config + + +def resize_input(video: torch.Tensor, resolution: list[int]): + r""" + Function to perform aspect ratio preserving resizing and center cropping. + This is needed to make the video into target resolution. + Args: + video (torch.Tensor): Input video tensor + resolution (list[int]): Data resolution + Returns: + Cropped video + """ + + orig_h, orig_w = video.shape[2], video.shape[3] + target_h, target_w = resolution + + scaling_ratio = max((target_w / orig_w), (target_h / orig_h)) + resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w))) + video_resized = torchvision.transforms.functional.resize(video, resizing_shape) + video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution) + return video_cropped + + +def load_image_from_list(flist, data_resolution: List[int]) -> dict: + """ + Function to load images from a list of image paths. + Args: + flist (List[str]): List of image paths + data_resolution (List[int]): Data resolution + Returns: + Dict containing input images + """ + all_videos = dict() + for img_path in flist: + ext = os.path.splitext(img_path)[1] + if ext in _IMAGE_EXTENSIONS: + # Read the image + img = Image.open(img_path) + + # Convert to tensor + img = torchvision.transforms.functional.to_tensor(img) + static_vid = img.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1) + static_vid = static_vid * 2 - 1 + + log.debug( + f"Resizing input image of shape ({static_vid.shape[2]}, {static_vid.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})" + ) + static_vid = resize_input(static_vid, data_resolution) + fname = os.path.basename(img_path) + all_videos[fname] = static_vid.transpose(0, 1).unsqueeze(0) + + return all_videos + + +def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict: + """ + Function to read input images from a JSONL file. + + Args: + batch_input_path (str): Path to JSONL file containing visual input paths + data_resolution (list[int]): Data resolution + + Returns: + Dict containing input images + """ + # Read visual inputs from JSONL + flist = [] + with open(batch_input_path, "r") as f: + for line in f: + data = json.loads(line.strip()) + flist.append(data["visual_input"]) + + return load_image_from_list(flist, data_resolution=data_resolution) + + +def read_input_image(input_path: str, data_resolution: List[int]) -> dict: + """ + Function to read input image. + Args: + input_path (str): Path to input image + data_resolution (List[int]): Data resolution + Returns: + Dict containing input image + """ + flist = [input_path] + return load_image_from_list(flist, data_resolution=data_resolution) + + +def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict: + r""" + Function to read input videos. 
+ Args: + batch_input_path (str): Path to JSONL file containing visual input paths + data_resolution (list[int]): Data resolution + Returns: + Dict containing input videos + """ + # Read visual inputs from JSONL + flist = [] + with open(batch_input_path, "r") as f: + for line in f: + data = json.loads(line.strip()) + flist.append(data["visual_input"]) + return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames) + + +def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict: + """ + Function to read input video. + Args: + input_path (str): Path to input video + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input video + """ + flist = [input_path] + return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames) + + +def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict: + """ + Function to load videos from a list of video paths. + Args: + flist (List[str]): List of video paths + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input videos + """ + all_videos = dict() + + for video_path in flist: + ext = os.path.splitext(video_path)[-1] + if ext in _VIDEO_EXTENSIONS: + video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec") + video = video.float() / 255.0 + video = video * 2 - 1 + + # Resize the videos to the required dimension + nframes_in_video = video.shape[0] + if nframes_in_video < num_input_frames: + fname = os.path.basename(video_path) + log.warning( + f"Video {fname} has {nframes_in_video} frames, less than the requried {num_input_frames} frames. Skipping." + ) + continue + + video = video[-num_input_frames:, :, :, :] + + # Pad the video to NUM_TOTAL_FRAMES (because the tokenizer expects inputs of NUM_TOTAL_FRAMES) + video = torch.cat( + (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)), + dim=0, + ) + + video = video.permute(0, 3, 1, 2) + + log.debug( + f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})" + ) + video = resize_input(video, data_resolution) + + fname = os.path.basename(video_path) + all_videos[fname] = video.transpose(0, 1).unsqueeze(0) + + return all_videos + + +def load_vision_input( + input_type: str, + batch_input_path: str, + input_image_or_video_path: str, + data_resolution: List[int], + num_input_frames: int, +): + """ + Function to load vision input. + Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed the padded video tensors to the video tokenizer to obtain tokens. The tokens will be truncated based on num_input_frames when feeding to the autoregressive model. 
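    Example (illustrative; assumes a 3-channel .mp4 with at least 9 frames at the hypothetical
    path ``assets/example.mp4``):
        >>> vids = load_vision_input("video", None, "assets/example.mp4", [640, 1024], num_input_frames=9)
        >>> next(iter(vids.values())).shape
        torch.Size([1, 3, 33, 640, 1024])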
+ Args: + input_type (str): Type of input + batch_input_path (str): Folder containing input images or videos + input_image_or_video_path (str): Path to input image or video + data_resolution (List[int]): Data resolution + num_input_frames (int): Number of frames in context + Returns: + Dict containing input videos + """ + if batch_input_path: + log.info(f"Reading batch inputs from path: {batch_input_path}") + if input_type == "image" or input_type == "text_and_image": + input_videos = read_input_images(batch_input_path, data_resolution=data_resolution) + elif input_type == "video" or input_type == "text_and_video": + input_videos = read_input_videos( + batch_input_path, + data_resolution=data_resolution, + num_input_frames=num_input_frames, + ) + else: + raise ValueError(f"Invalid input type {input_type}") + else: + if input_type == "image" or input_type == "text_and_image": + input_videos = read_input_image(input_image_or_video_path, data_resolution=data_resolution) + elif input_type == "video" or input_type == "text_and_video": + input_videos = read_input_video( + input_image_or_video_path, + data_resolution=data_resolution, + num_input_frames=num_input_frames, + ) + else: + raise ValueError(f"Invalid input type {input_type}") + return input_videos + + +def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]: + """ + Function to convert output tensors to numpy format for saving. + Args: + video_batch (List[torch.Tensor]): List of output tensors + Returns: + List of numpy arrays + """ + return [(video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() for video in video_batch] diff --git a/cosmos_predict1/autoregressive/utils/misc.py b/cosmos_predict1/autoregressive/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8e781bcf3538bd26b2aa45adce8a2921b9a14f --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/misc.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from omegaconf import DictConfig, OmegaConf + + +class CustomSimpleNamespace: + """ + A simple namespace class that supports both attribute-style and dictionary-style access. + """ + + def __init__(self, d): + self._d = d + + def __getattr__(self, attr): + # Attribute-style access: config.key + try: + return self._d[attr] + except KeyError: + raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'") + + def __getitem__(self, key): + # Dictionary-style access: config['key'] + return self._d[key] + + +def maybe_convert_to_namespace(config): + """ + This function cast a OmegaConf's DictConfig or a standard dict to CustomSimpleNamespace, which supports both + attribute-style and dictionary-style access. + Note: We need to convert OmegaConf's DictConfig since it is not compatible with torch.compile. 
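    Example (illustrative):
        >>> cfg = maybe_convert_to_namespace({"dim": 4096, "n_heads": 32})
        >>> cfg.dim == cfg["dim"] == 4096
        True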
+ """ + # If input is OmegaConf's DictConfig, convert to a standard dict + if isinstance(config, DictConfig): + config = OmegaConf.to_container(config, resolve=True) + + if isinstance(config, dict): + return CustomSimpleNamespace(config) + else: + return config + + +def random_dropout(embeddings, drop_rate): + r""" + Function to perform random dropout for embeddings. + When we drop embeddings, we zero them out. + Args: + embeddings (tensor): Input embeddings + drop_rate (float): Rate of dropping the embedding. + """ + num_samples = embeddings.shape[0] + # Create a shape (num_samples, 1, 1, 1, 1, ...) depending on embeddings dim. + # This is done to ensure we can broadcast the zero_flag to the embeddings. + # embeddings.ndim is 3 for images, and 4 for videos, and the corresponding + # shapes are (num_samples, 1, 1) and (num_samples, 1, 1, 1) respectively. + tensor_shape = (num_samples,) + tuple([1] * (embeddings.ndim - 1)) + zero_flag = torch.ones(tensor_shape).to(embeddings.dtype) * (1 - drop_rate) + zero_flag = torch.bernoulli(zero_flag).to(embeddings.device) + embeddings = embeddings * zero_flag + return embeddings diff --git a/cosmos_predict1/autoregressive/utils/parallel.py b/cosmos_predict1/autoregressive/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..05f7733aea75175eae7d3b1b68d2b004b377a6c8 --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/parallel.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +import torch.distributed as dist +from megatron.core import mpu, parallel_state +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.autograd import Function +from torch.distributed import broadcast, get_process_group_ranks +from transformer_engine.pytorch.jit import no_torch_dynamo +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE +from transformer_engine.pytorch.module.rmsnorm import _RMSNorm + +from cosmos_predict1.utils import log + + +def get_batch_on_this_cp_rank(inputs): + """Slice batch input along sequence dimension into multiple chunks, + which are parallelized across GPUs in a context parallel group. + """ + + # With causal masking, each token only attends to its prior tokens. Simply split + # sequence into CP chunks can result in severe load imbalance. That's to say, chunks + # at the end of sequence have bigger workload than others. To address this issue, + # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 + # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so + # that we can get balanced workload among GPUs in a context parallel group. 
+ cp_size = parallel_state.get_context_parallel_world_size() + + if cp_size > 1: + cp_rank = mpu.get_context_parallel_rank() + seq_dim = 1 # if key != 'attention_mask' else 2 + inputs = inputs.view( + *inputs.shape[0:seq_dim], + 2 * cp_size, + inputs.shape[seq_dim] // (2 * cp_size), + *inputs.shape[(seq_dim + 1) :], + ) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda( + non_blocking=True + ) + inputs = inputs.index_select(seq_dim, index) + inputs = inputs.view(*inputs.shape[0:seq_dim], -1, *inputs.shape[(seq_dim + 2) :]) + + return inputs + + +def gather_batch_from_cp_ranks(outputs): + """ + Gather and reconstruct the full batch from chunks distributed across GPUs in a context parallel group. + """ + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + + if cp_size > 1: + seq_dim = 1 # Assuming sequence dimension is 1 + + try: + # Reshape output to separate the two chunks + chunk_size = outputs.shape[seq_dim] // 2 + outputs = outputs.view(*outputs.shape[:seq_dim], 2, chunk_size, *outputs.shape[seq_dim + 1 :]) + + # Prepare a list to gather all chunks from all ranks + gathered_chunks = [torch.zeros_like(outputs) for _ in range(cp_size)] + + # Gather all chunks + dist.barrier() + dist.all_gather(gathered_chunks, outputs, group=parallel_state.get_context_parallel_group()) + dist.barrier() + + # Reorder chunks + reordered_chunks = [None] * (2 * cp_size) + for i in range(cp_size): + reordered_chunks[i] = gathered_chunks[i].select(seq_dim, 0) + reordered_chunks[2 * cp_size - 1 - i] = gathered_chunks[i].select(seq_dim, 1) + + # Concatenate all chunks + outputs = torch.cat(reordered_chunks, dim=seq_dim) + except Exception as e: + log.info(f"[Rank {cp_rank}] Error in gather_batch_from_cp_ranks: {str(e)}") + raise + + return outputs + + +def broadcast_data_batch_in_tp_cp_group(data_batch): + """ + Broadcast data batch across tensor model parallel and context parallel groups. + """ + keys = sorted(data_batch.keys()) + tp_size = parallel_state.get_tensor_model_parallel_world_size() + cp_size = parallel_state.get_context_parallel_world_size() + tp_group = parallel_state.get_tensor_model_parallel_group() if tp_size > 1 else None + cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None + tp_ranks = get_process_group_ranks(tp_group) if tp_size > 1 else None + cp_ranks = get_process_group_ranks(cp_group) if cp_size > 1 else None + if tp_size > 1 or cp_size > 1: + for key in keys: + tensor = data_batch[key] + if isinstance(tensor, torch.Tensor): + tensor = tensor.contiguous() + if tp_size > 1: + broadcast(tensor, min(tp_ranks), group=tp_group) + if cp_size > 1: + broadcast(tensor, min(cp_ranks), group=cp_group) + + +def allreduce_layernorm_grads(model: List[torch.nn.Module], tensor_model_parallel_size: int, sequence_parallel: bool): + """ + All-reduce layernorm grads (for sequence parallelism). + Note: + - We skip QK Normalization layers and the last normalization layer of Transformer, + since we use AllReduceBWDRMSNormTE for these layers, which already applies all-reduce in the backward pass. + - TransformerEngine's LayernormLinear and LayernormMLP modules have `*.layer_norm_weight` parameters that + we must all-reduce in the backward pass as well. So we implement this function to cover these parameters. 
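    Example (illustrative sketch; ``model_chunks`` is assumed to be the list of module chunks held
    by the current rank and ``tp_size`` the tensor-model-parallel world size, called once per step
    after the backward pass):

        allreduce_layernorm_grads(model_chunks, tensor_model_parallel_size=tp_size, sequence_parallel=True)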
+ """ + # All-reduce layernorm parameters across model parallel nodes + # when sequence parallelism is used + if tensor_model_parallel_size > 1 and sequence_parallel: + grads = [] + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + if name.endswith(".layer_norm_weight"): # TP # Q-layernorm # K-layernorm + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if grads: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + +def sync_1d_parameters(model: torch.nn.Module, process_group=None): + """ + Synchronize layernorm parameters (1D) across ranks by performing all-reduce with mean operation. + LayerNorm parameters are identified by having ndim==1. + Note: If parameters other than LayerNorm are 1D, they will also be synchronized. + + Args: + model (torch.nn.Module): The model containing layernorm parameters + process_group (optional): The process group to perform all-reduce. + If None, uses the default process group. + """ + if not torch.distributed.is_initialized(): + return + # Synchronize each 1D parameter (layernorm parameters) + for name, param in model.named_parameters(): + if param.ndim == 1 and param.requires_grad: # LayerNorm weights/biases are 1D + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.AVG, group=process_group) + + +class AllReduceBWD(Function): + """ + Custom autograd Function that performs an all-reduce operation during the backward pass. + + Args: + tensor (Tensor): The input tensor. + process_group: The process group to perform the all-reduce operation. + + Returns: + Tensor: The input tensor in the forward pass, and the all-reduced gradient in the backward pass. + """ + + @staticmethod + def forward(ctx, tensor, process_group): + ctx.process_group = process_group + return tensor + + @staticmethod + def backward(ctx, grad_output): + dist.all_reduce(grad_output, group=ctx.process_group) + return grad_output, None + + +class AllReduceBWDRMSNormTE(RMSNormTE): + """ + A custom RMSNorm layer that applies all-reduce operation during backward pass. + Used in tensor parallel training with Transformer Engine. + + Args: + hidden_size (int): The size of the hidden dimension. + process_group: Megatron Core's process group. + **kwargs: Additional arguments to be passed to RMSNormTE. + """ + + def __init__(self, hidden_size, process_group, **kwargs): + super().__init__(hidden_size, **kwargs) + self.process_group = process_group + + @no_torch_dynamo() + def forward(self, inp: torch.Tensor) -> torch.Tensor: + """RMSNorm FWD""" + + # Set the activation type for AMP. 
+ TransformerEngineBaseModule.set_activation_dtype(self, inp) + + if torch.is_grad_enabled(): + fwd_fn = _RMSNorm.apply + args = [] + else: + fwd_fn = _RMSNorm.forward + args = [None] + + args += ( + inp, + AllReduceBWD.apply(self.weight, self.process_group), + self.eps, + self.fwd_rmsnorm_sm_margin, + self.bwd_rmsnorm_sm_margin, + self.inf_rmsnorm_sm_margin, + self.zero_centered_gamma, + torch.is_grad_enabled(), + self.activation_dtype, + ) + + return fwd_fn(*args) diff --git a/cosmos_predict1/autoregressive/utils/sampling.py b/cosmos_predict1/autoregressive/utils/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..46d364c3e4b3f84f2f29baa4551f8e17a2f601ad --- /dev/null +++ b/cosmos_predict1/autoregressive/utils/sampling.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from torch.nn.attention import SDPBackend, sdpa_kernel + +from cosmos_predict1.autoregressive.networks.transformer import Transformer + + +def sample_top_p(logits, temperature, top_p, return_probs: bool = False): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + logits (torch.Tensor): Logits of the probability distribution. + temperature (float): Temperature for sampling. + top_p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1) + # Sort the probabilities in descending order and get their indices. + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + # Compute the cumulative sum of the sorted probabilities. + probs_sum = torch.cumsum(probs_sort, dim=-1) + # Create a mask where the cumulative probability exceeds the threshold p. + mask = probs_sum - probs_sort > top_p + # Set the probabilities that exceed the threshold to 0. + probs_sort[mask] = 0.0 + # Renormalize the remaining probabilities so they sum to 1. + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + # Sample from the renormalized probability distribution. + # next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64) + # Gather the indices of the sampled tokens. 
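    # (Illustrative walk-through: with sorted probs [0.6, 0.25, 0.1, 0.05] and top_p=0.8, the
    # cumulative sums are [0.6, 0.85, 0.95, 1.0]; the mask `cumsum - prob > top_p` zeroes the last
    # two entries, the survivors are renormalized to ~[0.71, 0.29], one index is drawn from them,
    # and the gather below maps that index back to the original vocabulary position.)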
+ next_token = torch.gather(probs_idx, -1, next_token) + if return_probs: + # Initialize a tensor for unsorted probabilities + probs_unsorted = torch.zeros_like(probs_sort) + # Scatter the sorted probabilities back to their original order + probs_unsorted.scatter_(-1, probs_idx, probs_sort) + else: + probs_unsorted = None + return next_token, probs_unsorted + + +def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int): + """ + Multinomial sampling without a cuda synchronization. + Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py + """ + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype) + + +def logits_to_probs( + logits, + temperature: float = 1.0, + top_k: Optional[int] = None, +): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None): + """ + Sample from the logits using top-k sampling. + Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py + """ + # logits: [batch_size, seq_len, vocab_size] + if temperature == 0.0: + idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True) + probs = None + else: + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, + input_pos: torch.Tensor, + tokens: torch.Tensor = None, + token_embeddings: torch.Tensor = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, +) -> torch.Tensor: + logits = model(tokens=tokens, token_embeddings=token_embeddings, input_pos=input_pos, **kwargs) + # Only top-p or top-k can be provided + assert ( + top_p is None or top_k is None + ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" + if top_p is not None: + return sample_top_p(logits, temperature=temperature, top_p=top_p)[0] + else: + return sample_top_k(logits, temperature=temperature, top_k=top_k)[0] + + +def decode_one_token( + model: Transformer, + tokens: torch.Tensor, + input_pos: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Decode a single token from the autoregressive model. + """ + logits = model(tokens=tokens, input_pos=input_pos, **kwargs) + if top_p is not None: + return sample_top_p(logits, temperature=temperature, top_p=top_p) + else: + return sample_top_k(logits, temperature=temperature, top_k=top_k) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + stop_tokens: torch.Tensor = None, + temperature: float = 1.0, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + return_probs: bool = False, + decode_one_token_function=decode_one_token, + **kwargs, +): + """ + Decode n tokens from the autoregressive model. 
+ Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py + """ + new_tokens, new_probs = [], [] + batch_size = cur_token.shape[0] + assert ( + top_p is None or top_k is None + ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" + if stop_tokens is not None: + # Indicator for whether the EOS token (stop token) has been reached for each sample in the batch + eos_reached = torch.tensor([False] * batch_size, device="cuda") + for t in range(num_new_tokens): + with sdpa_kernel([SDPBackend.MATH]): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token_function( + model, + tokens=cur_token, + input_pos=input_pos, + temperature=temperature, + top_k=top_k, + top_p=top_p, + **kwargs, + ) + input_pos += 1 + if stop_tokens is not None and len(stop_tokens) > 0: + eos_reached = eos_reached | (torch.isin(next_token, stop_tokens)) + if eos_reached.all(): + break + new_tokens.append(next_token.clone()) + if return_probs: + new_probs.append(next_prob.clone()) + cur_token = next_token.clone() + + if return_probs: + return new_tokens, new_probs + else: + return new_tokens diff --git a/cosmos_predict1/auxiliary/guardrail/__init__.py b/cosmos_predict1/auxiliary/guardrail/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/auxiliary/guardrail/aegis/__init__.py b/cosmos_predict1/auxiliary/guardrail/aegis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/aegis/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
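# The two sampling entry points in sampling.py are typically chained: prefill() consumes the full
# prompt and returns the first generated token, after which decode_n_tokens() produces the rest one
# token at a time. The snippet below is a minimal sketch of that loop, not the repo's actual
# inference driver; the hypothetical generate_sketch() helper assumes `model` is a Transformer
# whose KV caches have already been set up for batch size B.

import torch

from cosmos_predict1.autoregressive.utils.sampling import decode_n_tokens, prefill


def generate_sketch(model, prompt_tokens: torch.Tensor, num_new_tokens: int, top_p: float = 0.8):
    B, T = prompt_tokens.shape
    device = prompt_tokens.device
    # Prefill over the whole prompt; returns the first sampled token of shape (B, 1).
    first_token = prefill(model, input_pos=torch.arange(T, device=device), tokens=prompt_tokens, top_p=top_p)
    # Decode the remaining tokens one by one, starting at position T.
    new_tokens = decode_n_tokens(
        model,
        cur_token=first_token,
        input_pos=torch.tensor([T], device=device),
        num_new_tokens=num_new_tokens - 1,
        top_p=top_p,
    )
    return torch.cat([prompt_tokens, first_token] + new_tokens, dim=1)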
diff --git a/cosmos_predict1/auxiliary/guardrail/aegis/aegis.py b/cosmos_predict1/auxiliary/guardrail/aegis/aegis.py new file mode 100644 index 0000000000000000000000000000000000000000..ade12a7cf8588768511aaa6f282850c0c2252d1d --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/aegis/aegis.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +from cosmos_predict1.auxiliary.guardrail.aegis.categories import UNSAFE_CATEGORIES +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.utils import log, misc + +SAFE = misc.Color.green("SAFE") +UNSAFE = misc.Color.red("UNSAFE") + + +class Aegis(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = torch.bfloat16 + + base_model_id = "meta-llama/LlamaGuard-7b" + aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0" + base_model_dir = os.path.join(self.checkpoint_dir, base_model_id) + aegis_adapter_dir = os.path.join(self.checkpoint_dir, aegis_adapter) + + base_model = AutoModelForCausalLM.from_pretrained(base_model_dir) + self.tokenizer = AutoTokenizer.from_pretrained(base_model_dir) + self.model = PeftModel.from_pretrained(base_model, aegis_adapter_dir) + + self.model.to(self.device, dtype=self.dtype).eval() + + def get_moderation_prompt(self, user_prompt: str) -> str: + """Create the moderation prompt for the Aegis model.""" + unsafe_categories = "".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES]) + full_prompt = ( + "[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.\n" + "\n" + f"{unsafe_categories}\n" + "\n" + "\n" + ) + full_prompt += f"User: {user_prompt}\n\n" + full_prompt += """ + + Provide your safety assessment for User in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""" + + return full_prompt + + def get_aegis_block_message(self, moderation_output: str) -> str: + """Extract the blocked category and reason from the Aegis model output.""" + block_msg = "Prompt blocked by Aegis." 
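        # Illustrative note (format assumed from LlamaGuard-style outputs): an unsafe moderation
        # response typically looks like "unsafe\nO3", so the second line is stripped of its leading
        # "O" and mapped to UNSAFE_CATEGORIES[2] ("criminal_planning") to build the block message.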
+ try: + # Extract blocked category and reason + if len(moderation_output.split("\n")) >= 2: + moderation_output_list = moderation_output.split("\n") + blocked_category = int(moderation_output_list[1].lstrip("O")) - 1 + if 0 <= blocked_category < len(UNSAFE_CATEGORIES): + blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"] + blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"] + block_msg = f"{blocked_category_name}: {blocked_reason}" + except Exception as e: + log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}") + return block_msg + + def filter_aegis_output(self, prompt: str) -> tuple[bool, str]: + """Filter the Aegis model output and return the safety status and message.""" + full_prompt = self.get_moderation_prompt(prompt) + inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device) + output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id) + prompt_len = inputs["input_ids"].shape[-1] + moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True) + + if "unsafe" in moderation_output.lower(): + block_msg = self.get_aegis_block_message(moderation_output) + return False, block_msg + else: + return True, "" + + def is_safe(self, prompt: str) -> tuple[bool, str]: + """Check if the input prompt is safe according to the Aegis model.""" + try: + return self.filter_aegis_output(prompt) + except Exception as e: + log.error(f"Unexpected error occurred when running Aegis guardrail: {e}") + return True, "Unexpected error occurred when running Aegis guardrail." + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Aegis checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + aegis = Aegis(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[aegis]) + with misc.timer("aegis safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + log.info(f"Message: {message}") if not safety else None + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/aegis/categories.py b/cosmos_predict1/auxiliary/guardrail/aegis/categories.py new file mode 100644 index 0000000000000000000000000000000000000000..f9236da33c8d2669b701cc21709925bacbbe6907 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/aegis/categories.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +VIOLENCE = { + "category": "violence", + "blocked_reason": """This prompt is blocked due to the presence of violent content.""", + "instruction": """O1: Violence. 
+Should not +- Help people plan or engage in violence. +- Encourage people to engage in violence. +- Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader. +Can +- Provide information on violence. +- Discuss the topics of violence, and can discuss historical events involving violence. +""", +} + +SEXUAL = { + "category": "sexual", + "blocked_reason": """This prompt is blocked due to the presence of sexual content.""", + "instruction": """O2: Sexual. +Should not +- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. +- Engage or encourage sexual activity that is considered illegal. +Can +- Discuss topics related to sexuality and sexual education. +- Provide general information regarding sexual activities. +- Provide guidance on topics related to sexual health. +""", +} + +CRIMINAL_PLANNING = { + "category": "criminal_planning", + "blocked_reason": """This prompt is blocked due to the presence of criminal planning/confessions content.""", + "instruction": """O3: Criminal Planning/Confessions. +Should not +- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in criminal activities or condone criminal behavior. +- Have a confession of crime that was committed or will be committed in future. +Can +- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. +- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems. +""", +} + +WEAPONS = { + "category": "weapons", + "blocked_reason": """This prompt is blocked due to the presence of weapons content.""", + "instruction": """O4: Guns and Illegal Weapons. +Should not +- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. +- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. +- Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them. +Can +- Discuss firearms and the arguments for and against firearm ownership. +""", +} + +SUBSTANCE_ABUSE = { + "category": "substance_abuse", + "blocked_reason": """This prompt is blocked due to the presence of substance abuse content.""", + "instruction": """O5: Controlled/Regulated Substances. +Should not +- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. +- Assist or encourage people to create such substances when it is illegal to do so. +Can +- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). +- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation. 
+""", +} + +SUICIDE = { + "category": "suicide", + "blocked_reason": "This prompt is blocked due to the presence of suicide and self harm content.", + "instruction": """O6: Suicide and Self Harm. +Should not +- Encourage people to harm themselves. +- Romanticize or condone self-harm or suicide. +- Provide information regarding the methods of suicide or self-harm. +- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. +- Contain confessions or intent of self harm +Can +- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). +Should +- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves. +""", +} + +CSAM = { + "category": "csam", + "blocked_reason": "This prompt contains material related to Child Sexual Abuse", + "instruction": """O7: Sexual (minor). +Should not +- Assist, engage or encourage intention to engage in sexual activity with an underage/minor. +- Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor. +- Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media. +- Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children. +Can +- Provide information about sex education for use in schools. +""", +} + +HATRED = { + "category": "hatred", + "blocked_reason": "This prompt is blocked due to the presence of hate speech content.", + "instruction": """O8: Hate/Identity Hate. +Should not +- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their +race, color, ethnicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste. +- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. +- Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country. +- Should not contain slurs. +- Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear. +Can +- Provide information on discrimination. +- Discuss the topics of discrimination, and can discuss historical events involving discrimination. +- Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate). +""", +} + +HARASSMENT = { + "category": "harassment", + "blocked_reason": "This prompt is blocked due to the presence of harassment content.", + "instruction": """O10: Harassment. +Should not +- Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur. 
+- Engage in harassment, misconduct and unlawful discrimination +- Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex, +Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled), +Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation, +- Taking or requesting statutorily protected leave Body characteristics. +- Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct. +- Bullying. +- Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law. +- Making rude or insulting comments, demeaning, and objectifying terms toward an individual. +Can +- Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting. +- Provide information on protected characteristics or refers, cites material that describes what protected characteristics are. +- Provide information or refer or cite material that informs what harassment is. +""", +} + +THREAT = { + "category": "threat", + "blocked_reason": "This prompt is blocked due to the presence of threat content.", + "instruction": """O11: Threat. +Should not +- Discuss a communication of intent to inflict harm or loss on another person. +- Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control. +- Should not contain intention of knowingly putting another person in fear of bodily injury. +- Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person. +- Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will. +Can +- Provide information about threat is and the meaning of threat. +""", +} + +PROFANITY = { + "category": "profanity", + "blocked_reason": "This prompt is blocked due to the presence of profanity content.", + "instruction": """O12: Profanity. +Should not +- Contain swear words, curse words, or other obscene or profane language +- This includes offensive words used without any intention to act on them or without any explicit targeted recipient +""", +} + +UNSAFE_CATEGORIES = [ + VIOLENCE, + SEXUAL, + CRIMINAL_PLANNING, + WEAPONS, + SUBSTANCE_ABUSE, + SUICIDE, + CSAM, + HATRED, + HARASSMENT, + THREAT, + PROFANITY, +] diff --git a/cosmos_predict1/auxiliary/guardrail/blocklist/__init__.py b/cosmos_predict1/auxiliary/guardrail/blocklist/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/blocklist/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/auxiliary/guardrail/blocklist/blocklist.py b/cosmos_predict1/auxiliary/guardrail/blocklist/blocklist.py new file mode 100644 index 0000000000000000000000000000000000000000..d0fa7aafc609e69677f2ffb319fa11eb9a689fe5 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/blocklist/blocklist.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import re +import string +from difflib import SequenceMatcher + +import nltk +from better_profanity import profanity + +from cosmos_predict1.auxiliary.guardrail.blocklist.utils import read_keyword_list_from_dir, to_ascii +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.utils import log, misc + +CENSOR = misc.Color.red("*") + + +class Blocklist(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + guardrail_partial_match_min_chars: int = 6, + guardrail_partial_match_letter_count: float = 0.4, + ) -> None: + self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/blocklist") + nltk.data.path.append(os.path.join(self.checkpoint_dir, "nltk_data")) + self.lemmatizer = nltk.WordNetLemmatizer() + self.profanity = profanity + self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars + self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count + + # Load blocklist and whitelist keywords + self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom")) + self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist")) + self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match")) + + self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words) + log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist") + log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist") + log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist") + + def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str: + """Explicitly uncensor words that are in the whitelist.""" + input_words = input_prompt.split() + censored_words = censored_prompt.split() + whitelist_words = set(self.whitelist_words) + for i, token in enumerate(input_words): + if token.strip(string.punctuation).lower() in 
whitelist_words: + censored_words[i] = token + censored_prompt = " ".join(censored_words) + return censored_prompt + + def censor_prompt(self, input_prompt: str) -> tuple[bool, str]: + """Censor the prompt using the blocklist with better-profanity fuzzy matching. + + Args: + input_prompt: input prompt to censor + + Returns: + bool: True if the prompt is blocked, False otherwise + str: A message indicating why the prompt was blocked + """ + censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR) + # Uncensor whitelisted words that were censored from blocklist fuzzy matching + censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt) + if CENSOR in censored_prompt: + return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}" + return False, "" + + @staticmethod + def check_partial_match( + normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float + ) -> tuple[bool, str]: + """ + Check robustly if normalized word and the matching target have a difference of up to guardrail_partial_match_letter_count characters. + + Args: + normalized_prompt: a string with many words + normalized_word: a string with one or multiple words, its length is smaller than normalized_prompt + guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters) + + Returns: + bool: True if a match is found, False otherwise + str: A message indicating why the prompt was blocked + """ + prompt_words = normalized_prompt.split() + word_length = len(normalized_word.split()) + max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float( + len(normalized_word) + ) + + for i in range(len(prompt_words) - word_length + 1): + # Extract a substring from the prompt with the same number of words as the normalized_word + substring = " ".join(prompt_words[i : i + word_length]) + similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio() + if similarity_ratio >= max_similarity_ratio: + return ( + True, + f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}", + ) + + return False, "" + + @staticmethod + def check_against_whole_word_blocklist( + prompt: str, + blocklist: list[str], + guardrail_partial_match_min_chars: int = 6, + guardrail_partial_match_letter_count: float = 0.4, + ) -> bool: + """ + Check if the prompt contains any whole words from the blocklist. + The match is case insensitive and robust to multiple spaces between words. 
+ + Args: + prompt: input prompt to check + blocklist: list of words to check against + guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match + guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match + + Returns: + bool: True if a match is found, False otherwise + str: A message indicating why the prompt was blocked + """ + # Normalize spaces and convert to lowercase + normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower() + + for word in blocklist: + # Normalize spaces and convert to lowercase for each blocklist word + normalized_word = re.sub(r"\s+", " ", word).strip().lower() + + # Use word boundaries to ensure whole word match + if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt): + return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}" + + # Check for partial match if the word is long enough + if len(normalized_word) >= guardrail_partial_match_min_chars: + match, message = Blocklist.check_partial_match( + normalized_prompt, normalized_word, guardrail_partial_match_letter_count + ) + if match: + return True, message + + return False, "" + + def is_safe(self, input_prompt: str = "") -> tuple[bool, str]: + """Check if the input prompt is safe using the blocklist.""" + # Check if the input is empty + if not input_prompt: + return False, "Input is empty" + input_prompt = to_ascii(input_prompt) + + # Check full sentence for censored words + censored, message = self.censor_prompt(input_prompt) + if censored: + return False, message + + # Check lemmatized words for censored words + tokens = nltk.word_tokenize(input_prompt) + lemmas = [self.lemmatizer.lemmatize(token) for token in tokens] + lemmatized_prompt = " ".join(lemmas) + censored, message = self.censor_prompt(lemmatized_prompt) + if censored: + return False, message + + # Check for exact match blocklist words + censored, message = self.check_against_whole_word_blocklist( + input_prompt, + self.exact_match_words, + self.guardrail_partial_match_min_chars, + self.guardrail_partial_match_letter_count, + ) + if censored: + return False, message + + # If all these checks pass, the input is safe + return True, "Input is safe" + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Blocklist checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[blocklist]) + with misc.timer("blocklist safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + log.info(f"Message: {message}") if not safety else None + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/blocklist/utils.py b/cosmos_predict1/auxiliary/guardrail/blocklist/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8555af872b03dd8a9dad0dd2699550bdcdd5b1 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/blocklist/utils.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +from cosmos_predict1.utils import log + + +def read_keyword_list_from_dir(folder_path: str) -> list[str]: + """Read keyword list from all files in a folder.""" + output_list = [] + file_list = [] + # Get list of files in the folder + for file in os.listdir(folder_path): + if os.path.isfile(os.path.join(folder_path, file)): + file_list.append(file) + + # Process each file + for file in file_list: + file_path = os.path.join(folder_path, file) + try: + with open(file_path, "r") as f: + output_list.extend([line.strip() for line in f.readlines()]) + except Exception as e: + log.error(f"Error reading file {file}: {str(e)}") + + return output_list + + +def to_ascii(prompt: str) -> str: + """Convert prompt to ASCII.""" + return re.sub(r"[^\x00-\x7F]+", " ", prompt) diff --git a/cosmos_predict1/auxiliary/guardrail/common/__init__.py b/cosmos_predict1/auxiliary/guardrail/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/auxiliary/guardrail/common/core.py b/cosmos_predict1/auxiliary/guardrail/common/core.py new file mode 100644 index 0000000000000000000000000000000000000000..f4deeaa2ca0eb99d8b778665963221365ad3927d --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/common/core.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
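The partial-match rule in Blocklist.check_partial_match above blocks a candidate substring once its SequenceMatcher ratio reaches (L - guardrail_partial_match_letter_count) / L, where L is the length of the blocklist entry. A minimal sketch (not part of the diff) that exercises the static method directly, assuming the cosmos_predict1 package and its dependencies (nltk, better_profanity) are importable:

from cosmos_predict1.auxiliary.guardrail.blocklist.blocklist import Blocklist

# "forbidden" has 9 letters, so the block threshold is (9 - 0.4) / 9 ~= 0.956.
# An exact occurrence (ratio 1.0) is blocked; a one-character corruption such as
# "forbiden" (ratio ~= 0.94) would fall just below the threshold.
blocked, reason = Blocklist.check_partial_match(
    normalized_prompt="a forbidden word hidden in a prompt",
    normalized_word="forbidden",
    guardrail_partial_match_letter_count=0.4,
)
print(blocked)  # True
print(reason)   # "Prompt blocked by partial match blocklist: ..."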
+ +from typing import Any, Tuple + +import numpy as np + +from cosmos_predict1.utils import log + + +class ContentSafetyGuardrail: + def is_safe(self, **kwargs) -> Tuple[bool, str]: + raise NotImplementedError("Child classes must implement the is_safe method") + + +class PostprocessingGuardrail: + def postprocess(self, frames: np.ndarray) -> np.ndarray: + raise NotImplementedError("Child classes must implement the postprocess method") + + +class GuardrailRunner: + def __init__( + self, + safety_models: list[ContentSafetyGuardrail] | None = None, + generic_block_msg: str = "", + generic_safe_msg: str = "", + postprocessors: list[PostprocessingGuardrail] | None = None, + ): + self.safety_models = safety_models + self.generic_block_msg = generic_block_msg + self.generic_safe_msg = generic_safe_msg if generic_safe_msg else "Prompt is safe" + self.postprocessors = postprocessors + + def run_safety_check(self, input: Any) -> Tuple[bool, str]: + """Run the safety check on the input.""" + if not self.safety_models: + log.warning("No safety models found, returning safe") + return True, self.generic_safe_msg + + for guardrail in self.safety_models: + guardrail_name = str(guardrail.__class__.__name__).upper() + log.debug(f"Running guardrail: {guardrail_name}") + safe, message = guardrail.is_safe(input) + if not safe: + reasoning = self.generic_block_msg if self.generic_block_msg else f"{guardrail_name}: {message}" + return False, reasoning + return True, self.generic_safe_msg + + def postprocess(self, frames: np.ndarray) -> np.ndarray: + """Run the postprocessing on the video frames.""" + if not self.postprocessors: + log.warning("No postprocessors found, returning original frames") + return frames + + for guardrail in self.postprocessors: + guardrail_name = str(guardrail.__class__.__name__).upper() + log.debug(f"Running guardrail: {guardrail_name}") + frames = guardrail.postprocess(frames) + return frames diff --git a/cosmos_predict1/auxiliary/guardrail/common/io_utils.py b/cosmos_predict1/auxiliary/guardrail/common/io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..129f049233191368a6dee4ef202088fdd851e3e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/common/io_utils.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
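As a quick illustration of the interface defined in core.py above, the following sketch (not part of the diff) plugs a toy ContentSafetyGuardrail subclass into a GuardrailRunner; the KeywordGuard class and its keyword are invented for the example and assume the cosmos_predict1 package is importable.

from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner


class KeywordGuard(ContentSafetyGuardrail):
    """Toy guardrail that blocks any prompt containing a single keyword."""

    def __init__(self, keyword: str) -> None:
        self.keyword = keyword.lower()

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        if self.keyword in prompt.lower():
            return False, f"found blocked keyword '{self.keyword}'"
        return True, ""


runner = GuardrailRunner(safety_models=[KeywordGuard("dynamite")])
print(runner.run_safety_check("a calm beach at sunset"))  # (True, 'Prompt is safe')
print(runner.run_safety_check("a crate of dynamite"))     # (False, "KEYWORDGUARD: found blocked keyword 'dynamite'")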
+ +import glob +from dataclasses import dataclass + +import imageio +import numpy as np + +from cosmos_predict1.utils import log + + +@dataclass +class VideoData: + frames: np.ndarray # Shape: [B, H, W, C] + fps: int + duration: int # in seconds + + +def get_video_filepaths(input_dir: str) -> list[str]: + """Get a list of filepaths for all videos in the input directory.""" + paths = glob.glob(f"{input_dir}/**/*.mp4", recursive=True) + paths += glob.glob(f"{input_dir}/**/*.avi", recursive=True) + paths += glob.glob(f"{input_dir}/**/*.mov", recursive=True) + paths = sorted(paths) + log.debug(f"Found {len(paths)} videos") + return paths + + +def read_video(filepath: str) -> VideoData: + """Read a video file and extract its frames and metadata.""" + try: + reader = imageio.get_reader(filepath, "ffmpeg") + except Exception as e: + raise ValueError(f"Failed to read video file: {filepath}") from e + + # Extract metadata from the video file + try: + metadata = reader.get_meta_data() + fps = metadata.get("fps") + duration = metadata.get("duration") + except Exception as e: + reader.close() + raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e + + # Extract frames from the video file + try: + frames = np.array([frame for frame in reader]) + except Exception as e: + raise ValueError(f"Failed to extract frames from video file: {filepath}") from e + finally: + reader.close() + + return VideoData(frames=frames, fps=fps, duration=duration) + + +def save_video(filepath: str, frames: np.ndarray, fps: int) -> None: + """Save a video file from a sequence of frames.""" + try: + writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1) + for frame in frames: + writer.append_data(frame) + except Exception as e: + raise ValueError(f"Failed to save video file to {filepath}") from e + finally: + writer.close() diff --git a/cosmos_predict1/auxiliary/guardrail/common/presets.py b/cosmos_predict1/auxiliary/guardrail/common/presets.py new file mode 100644 index 0000000000000000000000000000000000000000..245fe445496e9c265742023f3c61ece7d82ee49e --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/common/presets.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
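A minimal round-trip sketch (not part of the diff) for the io_utils helpers above; the file paths are placeholders.

from cosmos_predict1.auxiliary.guardrail.common.io_utils import read_video, save_video

video = read_video("input.mp4")        # hypothetical input path
print(video.frames.shape, video.fps)   # (T, H, W, C) and the source frame rate
save_video("output.mp4", video.frames, video.fps)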
+ +import os +from pathlib import Path + +import numpy as np + +from cosmos_predict1.auxiliary.guardrail.blocklist.blocklist import Blocklist +from cosmos_predict1.auxiliary.guardrail.common.core import GuardrailRunner +from cosmos_predict1.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter +from cosmos_predict1.auxiliary.guardrail.llamaGuard3.llamaGuard3 import LlamaGuard3 +from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.video_content_safety_filter import ( + VideoContentSafetyFilter, +) +from cosmos_predict1.utils import log + + +def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: + """Create the text guardrail runner.""" + return GuardrailRunner(safety_models=[Blocklist(checkpoint_dir), LlamaGuard3(checkpoint_dir)]) + + +def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: + """Create the video guardrail runner.""" + return GuardrailRunner( + safety_models=[VideoContentSafetyFilter(checkpoint_dir)], + postprocessors=[RetinaFaceFilter(checkpoint_dir)], + ) + + +def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool: + """Run the text guardrail on the prompt, checking for content safety. + + Args: + prompt: The text prompt. + guardrail_runner: The text guardrail runner. + + Returns: + bool: Whether the prompt is safe. + """ + is_safe, message = guardrail_runner.run_safety_check(prompt) + if not is_safe: + log.critical(f"GUARDRAIL BLOCKED: {message}") + return is_safe + + +def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None: + """Run the video guardrail on the frames, checking for content safety and applying face blur. + + Args: + frames: The frames of the generated video. + guardrail_runner: The video guardrail runner. + + Returns: + The processed frames if safe, otherwise None. + """ + is_safe, message = guardrail_runner.run_safety_check(frames) + if not is_safe: + log.critical(f"GUARDRAIL BLOCKED: {message}") + return None + + frames = guardrail_runner.postprocess(frames) + return frames diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/__init__.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
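Taken together, the presets above give a two-call pattern for gating a text-to-video pipeline: check the prompt, then check and post-process the generated frames. A sketch of that flow (not part of the diff); it assumes the guardrail checkpoints have already been downloaded under ./checkpoints in the layout the individual guardrails expect, and generate_video is a placeholder for the actual generation call.

import numpy as np

from cosmos_predict1.auxiliary.guardrail.common.presets import (
    create_text_guardrail_runner,
    create_video_guardrail_runner,
    run_text_guardrail,
    run_video_guardrail,
)

checkpoint_dir = "checkpoints"  # assumed download location
text_runner = create_text_guardrail_runner(checkpoint_dir)
video_runner = create_video_guardrail_runner(checkpoint_dir)

prompt = "a robot arm stacking boxes in a warehouse"
if run_text_guardrail(prompt, text_runner):
    frames = generate_video(prompt)  # placeholder, not defined in this diff
    frames = run_video_guardrail(np.asarray(frames), video_runner)
    if frames is None:
        print("Generated video was blocked by the video guardrail")
else:
    print("Prompt was blocked by the text guardrail")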
diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/blur_utils.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/blur_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d52f69d220444a53027b3b4acc3bd192fc6eb76f --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/blur_utils.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray: + """ + Pixelate a face region by reducing resolution and then upscaling. + + Args: + face_img: Face region to pixelate + blocks: Number of blocks to divide the face into (in each dimension) + + Returns: + Pixelated face region + """ + h, w = face_img.shape[:2] + # Shrink the image and scale back up to create pixelation effect + temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR) + pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST) + return pixelated diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..eb26b9c4d0b0dca930336487c11f26373b6ab293 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
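The pixelation helper above preserves the input shape; a small sketch (not part of the diff) on a random patch:

import numpy as np

from cosmos_predict1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face

face = (np.random.rand(64, 48, 3) * 255).astype(np.uint8)  # dummy 64x48 face crop
blurred = pixelate_face(face, blocks=5)
assert blurred.shape == face.shape  # same height/width, contents reduced to a 5x5 block pattern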
+ +import argparse +import os +import warnings + +import numpy as np +import torch +from retinaface.data import cfg_re50 +from retinaface.layers.functions.prior_box import PriorBox +from retinaface.models.retinaface import RetinaFace +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm + +from cosmos_predict1.auxiliary.guardrail.common.core import GuardrailRunner, PostprocessingGuardrail +from cosmos_predict1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video, save_video +from cosmos_predict1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face +from cosmos_predict1.auxiliary.guardrail.face_blur_filter.retinaface_utils import ( + decode_batch, + filter_detected_boxes, + load_model, +) +from cosmos_predict1.utils import log, misc + +# RetinaFace model constants from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +TOP_K = 5_000 +KEEP_TOP_K = 750 +NMS_THRESHOLD = 0.4 + + +class RetinaFaceFilter(PostprocessingGuardrail): + def __init__( + self, + checkpoint_dir: str, + batch_size: int = 1, + confidence_threshold: float = 0.7, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + """ + Initialize the RetinaFace model for face detection and blurring. + + Args: + checkpoint: Path to the RetinaFace checkpoint file + batch_size: Batch size for RetinaFace inference and processing + confidence_threshold: Minimum confidence score to consider a face detection + """ + self.checkpoint = f"{checkpoint_dir}/nvidia/Cosmos-Guardrail1/face_blur_filter/Resnet50_Final.pth" + self.cfg = cfg_re50 + self.batch_size = batch_size + self.confidence_threshold = confidence_threshold + self.device = device + self.dtype = torch.float32 + + # Disable loading ResNet pretrained weights + self.cfg["pretrain"] = False + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.net = RetinaFace(cfg=self.cfg, phase="test") + cpu = self.device == "cpu" + + # Load from RetinaFace pretrained checkpoint + self.net = load_model(self.net, self.checkpoint, cpu) + self.net.to(self.device, dtype=self.dtype).eval() + + def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor: + """Preprocess a sequence of frames for face detection. + + Args: + frames: Input frames + + Returns: + Preprocessed frames tensor + """ + with torch.no_grad(): + frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype) # Shape: [T, H, W, C] + frames_tensor = frames_tensor.permute(0, 3, 1, 2) # Shape: [T, C, H, W] + frames_tensor = frames_tensor[:, [2, 1, 0], :, :] # RGB to BGR to match RetinaFace model input + means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1) + frames_tensor = frames_tensor - means # Subtract mean BGR values for each channel + return frames_tensor + + def blur_detected_faces( + self, + frames: np.ndarray, + batch_loc: torch.Tensor, + batch_conf: torch.Tensor, + prior_data: torch.Tensor, + scale: torch.Tensor, + min_size: tuple[int] = (20, 20), + ) -> list[np.ndarray]: + """Blur detected faces in a batch of frames using RetinaFace predictions. 
+ + Args: + frames: Input frames + batch_loc: Batched location predictions + batch_conf: Batched confidence scores + prior_data: Prior boxes for the video + scale: Scale factor for resizing detections + min_size: Minimum size of a detected face region in pixels + + Returns: + Processed frames with pixelated faces + """ + with torch.no_grad(): + batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"]) + batch_boxes = batch_boxes * scale + + blurred_frames = [] + for i, boxes in enumerate(batch_boxes): + boxes = boxes.detach().cpu().numpy() + scores = batch_conf[i, :, 1].detach().cpu().numpy() + + filtered_boxes = filter_detected_boxes( + boxes, + scores, + confidence_threshold=self.confidence_threshold, + nms_threshold=NMS_THRESHOLD, + top_k=TOP_K, + keep_top_k=KEEP_TOP_K, + ) + + frame = frames[i] + for box in filtered_boxes: + x1, y1, x2, y2 = map(int, box) + # Ignore bounding boxes smaller than the minimum size + if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]: + continue + max_h, max_w = frame.shape[:2] + face_roi = frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] + blurred_face = pixelate_face(face_roi) + frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] = blurred_face + blurred_frames.append(frame) + + return blurred_frames + + def postprocess(self, frames: np.ndarray) -> np.ndarray: + """Blur faces in a sequence of frames. + + Args: + frames: Input frames + + Returns: + Processed frames with pixelated faces + """ + # Create dataset and dataloader + frames_tensor = self.preprocess_frames(frames) + dataset = TensorDataset(frames_tensor) + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False) + processed_frames, processed_batches = [], [] + + prior_data, scale = None, None + for i, batch in enumerate(dataloader): + batch = batch[0] + h, w = batch.shape[-2:] # Batch shape: [C, H, W] + + with torch.no_grad(): + # Generate priors for the video + if prior_data is None: + priorbox = PriorBox(self.cfg, image_size=(h, w)) + priors = priorbox.forward() + priors = priors.to(self.device, dtype=self.dtype) + prior_data = priors.data + + # Get scale for resizing detections + if scale is None: + scale = torch.Tensor([w, h, w, h]) + scale = scale.to(self.device, dtype=self.dtype) + + batch_loc, batch_conf, _ = self.net(batch) + + # Blur detected faces in each batch of frames + start_idx = i * self.batch_size + end_idx = min(start_idx + self.batch_size, len(frames)) + processed_batches.append( + self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale) + ) + + processed_frames = [frame for batch in processed_batches for frame in batch] + return np.array(processed_frames) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") + parser.add_argument("--output_dir", type=str, required=True, help="Path for saving processed videos") + parser.add_argument( + "--checkpoint-dir", + type=str, + help="Path to the RetinaFace checkpoint file", + ) + return parser.parse_args() + + +def main(args): + filepaths = get_video_filepaths(args.input_dir) + if not filepaths: + log.error(f"No video files found in directory: {args.input_dir}") + return + + face_blur = RetinaFaceFilter(checkpoint_dir=args.checkpoint) + postprocessing_runner = GuardrailRunner(postprocessors=[face_blur]) + os.makedirs(args.output_dir, exist_ok=True) + + for filepath in tqdm(filepaths): + video_data = read_video(filepath) + with 
misc.timer("face blur filter"): + frames = postprocessing_runner.postprocess(video_data.frames) + + output_path = os.path.join(args.output_dir, os.path.basename(filepath)) + save_video(output_path, frames, video_data.fps) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea87e5818d643fb13bf59950c40c24d0cf36acf --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from retinaface.utils.nms.py_cpu_nms import py_cpu_nms + +from cosmos_predict1.utils import log + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def filter_detected_boxes(boxes, scores, confidence_threshold, nms_threshold, top_k, keep_top_k): + """Filter boxes based on confidence score and remove overlapping boxes using NMS.""" + # Keep detections with confidence above threshold + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + scores = scores[inds] + + # Sort by confidence and keep top K detections + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + scores = scores[order] + + # Run non-maximum-suppression (NMS) to remove overlapping boxes + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + dets = dets[:keep_top_k, :] + boxes = dets[:, :-1] + return boxes + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py to handle batched inputs +def decode_batch(loc, priors, variances): + """Decode batched locations from predictions using priors and variances. + + Args: + loc (tensor): Batched location predictions for loc layers. + Shape: [batch_size, num_priors, 4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors, 4] + variances: (list[float]): Variances of prior boxes. 
+ + Return: + Decoded batched bounding box predictions + Shape: [batch_size, num_priors, 4] + """ + batch_size = loc.size(0) + priors = priors.unsqueeze(0).expand(batch_size, -1, -1) + + boxes = torch.cat( + ( + priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]), + ), + dim=2, + ) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def _check_keys(model, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(model.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + unused_pretrained_keys = ckpt_keys - model_keys + missing_keys = model_keys - ckpt_keys + log.debug("Missing keys:{}".format(len(missing_keys))) + log.debug("Unused checkpoint keys:{}".format(len(unused_pretrained_keys))) + log.debug("Used keys:{}".format(len(used_pretrained_keys))) + assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint" + return True + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def _remove_prefix(state_dict, prefix): + """Old version of the model is stored with all names of parameters sharing common prefix 'module.'""" + log.debug("Removing prefix '{}'".format(prefix)) + + def f(x): + return x.split(prefix, 1)[-1] if x.startswith(prefix) else x + + return {f(key): value for key, value in state_dict.items()} + + +# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py +def load_model(model, pretrained_path, load_to_cpu): + log.debug("Loading pretrained model from {}".format(pretrained_path)) + if load_to_cpu: + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage, weights_only=True) + else: + device = torch.cuda.current_device() + pretrained_dict = torch.load( + pretrained_path, map_location=lambda storage, loc: storage.cuda(device), weights_only=True + ) + if "state_dict" in pretrained_dict.keys(): + pretrained_dict = _remove_prefix(pretrained_dict["state_dict"], "module.") + else: + pretrained_dict = _remove_prefix(pretrained_dict, "module.") + _check_keys(model, pretrained_dict) + model.load_state_dict(pretrained_dict, strict=False) + return model diff --git a/cosmos_predict1/auxiliary/guardrail/llamaGuard3/__init__.py b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
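To make the batched decoding concrete, the sketch below (not part of the diff) runs decode_batch on random tensors; the variances [0.1, 0.2] and the prior count are assumptions taken from typical RetinaFace settings and are only meant to show the shapes.

import torch

from cosmos_predict1.auxiliary.guardrail.face_blur_filter.retinaface_utils import decode_batch

batch_size, num_priors = 4, 16800           # 16800 priors corresponds to a 640x640 input; illustrative only
loc = torch.randn(batch_size, num_priors, 4)
priors = torch.rand(num_priors, 4)          # center-offset form (cx, cy, w, h) in [0, 1]
boxes = decode_batch(loc, priors, [0.1, 0.2])
print(boxes.shape)                          # torch.Size([4, 16800, 4]) -> (x1, y1, x2, y2) per prior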
diff --git a/cosmos_predict1/auxiliary/guardrail/llamaGuard3/categories.py b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/categories.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d5a95d4dce1202e3acec0e10177c97c1e5924e --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/categories.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +UNSAFE_CATEGORIES = { + "S1": "Violent Crimes.", + "S2": "Non-Violent Crimes.", + "S3": "Sex Crimes.", + "S4": "Child Exploitation.", + "S5": "Defamation.", + "S6": "Specialized Advice.", + "S7": "Privacy.", + "S8": "Intellectual Property.", + "S9": "Indiscriminate Weapons.", + "S10": "Hate.", + "S11": "Self-Harm.", + "S12": "Sexual Content.", + "S13": "Elections.", + "S14": "Code Interpreter Abuse.", +} diff --git a/cosmos_predict1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py new file mode 100644 index 0000000000000000000000000000000000000000..cae384738ed228bf4455d95c3c6fe011d790c01f --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
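Llama Guard 3 replies with a short completion such as "unsafe" followed by a comma-separated list of category codes on the next line; the mapping above turns those codes into readable names, and the parser in llamaGuard3.py below relies on that shape. A minimal sketch (not part of the diff), using a hand-written moderation string rather than real model output:

from cosmos_predict1.auxiliary.guardrail.llamaGuard3.categories import UNSAFE_CATEGORIES

moderation_output = "unsafe\nS1,S10"  # example string in the format described above
codes = [code.strip() for code in moderation_output.splitlines()[1].split(",")]
print([UNSAFE_CATEGORIES[code] for code in codes])  # ['Violent Crimes.', 'Hate.']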
+ +import argparse +import os + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.auxiliary.guardrail.llamaGuard3.categories import UNSAFE_CATEGORIES +from cosmos_predict1.utils import log, misc + +SAFE = misc.Color.green("SAFE") +UNSAFE = misc.Color.red("UNSAFE") + + +class LlamaGuard3(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = torch.bfloat16 + + model_id = "meta-llama/Llama-Guard-3-8B" + model_dir = os.path.join(self.checkpoint_dir, model_id) + + self.model = AutoModelForCausalLM.from_pretrained(model_dir) + self.tokenizer = AutoTokenizer.from_pretrained(model_dir) + + self.model.to(self.device, dtype=self.dtype).eval() + + def get_llamaGuard3_block_message(self, moderation_output: str) -> str: + """Extract the blocked category from the Llama Guard 3 model output.""" + block_msg = "Prompt blocked by Llama Guard 3." + try: + lines = moderation_output.splitlines() + categories_detected = [] + for line in lines[1:]: + line_stripped = line.split("<|eot_id|>")[0].strip() + for catagory in line_stripped.split(","): + catagory = catagory.strip() + if catagory not in UNSAFE_CATEGORIES: + log.warning(f"Unrecognized category from moderation output: {catagory}") + else: + categories_detected.append(catagory) + if len(categories_detected) > 0: + blocked_catagories = ", ".join([UNSAFE_CATEGORIES[catagory][:-1] for catagory in categories_detected]) + block_msg = f"{block_msg} Violations: {blocked_catagories}." + except Exception as e: + log.warning(f"Unable to extract blocked category from Llama Guard 3 output: {e}") + return block_msg + + def filter_llamaGuard3_output(self, prompt: str) -> tuple[bool, str]: + """Filter the Llama Guard 3 model output and return the safety status and message.""" + conversation = [{"role": "user", "content": prompt}] + input_ids = self.tokenizer.apply_chat_template( + conversation, categories=UNSAFE_CATEGORIES, return_tensors="pt" + ).to("cuda") + prompt_len = input_ids.shape[1] + output = self.model.generate( + input_ids=input_ids, + max_new_tokens=100, + return_dict_in_generate=True, + pad_token_id=0, + ) + generated_tokens = output.sequences[:, prompt_len:] + moderation_output = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=False).strip() + + if "unsafe" in moderation_output.lower(): + block_msg = self.get_llamaGuard3_block_message(moderation_output) + return False, block_msg + else: + return True, "" + + def is_safe(self, prompt: str) -> tuple[bool, str]: + """Check if the input prompt is safe according to the Llama Guard 3 model.""" + try: + return self.filter_llamaGuard3_output(prompt) + except Exception as e: + log.error(f"Unexpected error occurred when running Llama Guard 3 guardrail: {e}") + return True, "Unexpected error occurred when running Llama Guard 3 guardrail." 
+ + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", type=str, required=True, help="Input prompt") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Llama Guard 3 checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + llamaGuard3 = LlamaGuard3(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[llamaGuard3]) + with misc.timer("Llama Guard 3 safety check"): + safety, message = runner.run_safety_check(args.prompt) + log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") + log.info(f"Message: {message}") if not safety else None + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/__init__.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/model.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6dabaa352257260cb6f6462e86f4d966d1b67118 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/model.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import attrs +import torch +import torch.nn as nn + +from cosmos_predict1.utils.config import make_freezable + + +@make_freezable +@attrs.define(slots=False) +class ModelConfig: + input_size: int = 1152 + num_classes: int = 7 + + +class SafetyClassifier(nn.Module): + def __init__(self, input_size: int = 1024, num_classes: int = 2): + super().__init__() + self.input_size = input_size + self.num_classes = num_classes + self.layers = nn.Sequential( + nn.Linear(self.input_size, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, self.num_classes), + # Note: No activation function here; CrossEntropyLoss expects raw logits + ) + + def forward(self, x): + return self.layers(x) + + +class VideoSafetyModel(nn.Module): + def __init__(self, config: ModelConfig) -> None: + super().__init__() + self.config = config + self.num_classes = config.num_classes + self.network = SafetyClassifier(input_size=config.input_size, num_classes=self.num_classes) + + @torch.inference_mode() + def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + logits = self.network(data_batch["data"].cuda()) + return {"logits": logits} diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..6bccb8dad67e1baf2649d6c6d83f29e5a09f8445 --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
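The classifier above is a plain MLP over SigLIP embeddings, so its input/output contract can be sanity-checked with random features. A sketch (not part of the diff) that assumes a CUDA device, since VideoSafetyModel.forward moves the batch to the GPU; the weights here are random, not the released safety_filter.pt checkpoint.

import torch

from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel

model = VideoSafetyModel(ModelConfig(input_size=1152, num_classes=7)).cuda().eval()
dummy = {"data": torch.randn(8, 1152)}  # 8 fake SigLIP embeddings
logits = model(dummy)["logits"]         # shape [8, 7], one logit per class
print(logits.argmax(dim=-1))            # predicted class index per frame (index 0 maps to "Safe" in the filter that follows)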
+ +import argparse +import json +import os +from typing import Iterable, Tuple, Union + +import torch +from PIL import Image + +from cosmos_predict1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner +from cosmos_predict1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video +from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel +from cosmos_predict1.auxiliary.guardrail.video_content_safety_filter.vision_encoder import SigLIPEncoder +from cosmos_predict1.utils import log, misc + +# Define the class index to class name mapping for multi-class classification +CLASS_IDX_TO_NAME = { + 0: "Safe", + 1: "Sexual_Content", + 2: "Violence", + 3: "Drugs", + 4: "Child_Abuse", + 5: "Hate_and_Harassment", + 6: "Self-Harm", +} + + +class VideoContentSafetyFilter(ContentSafetyGuardrail): + def __init__( + self, + checkpoint_dir: str, + device="cuda" if torch.cuda.is_available() else "cpu", + ) -> None: + self.device = device + self.dtype = torch.float32 + + # Initialize the SigLIP encoder + self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/video_content_safety_filter") + self.encoder = SigLIPEncoder(checkpoint_dir=self.checkpoint_dir, device=device, dtype=self.dtype) + + # Use ModelConfig directly for inference configuration + model_config = ModelConfig(input_size=1152, num_classes=7) + + # Load the multi-class classifier + self.model = VideoSafetyModel(model_config) + safety_filter_local_path = os.path.join(self.checkpoint_dir, "safety_filter.pt") + checkpoint = torch.load(safety_filter_local_path, map_location=torch.device("cpu"), weights_only=True) + self.model.load_state_dict(checkpoint["model"]) + self.model.to(self.device, dtype=self.dtype).eval() + + @torch.inference_mode() + def __infer(self, pil_image: Image.Image) -> int: + """Infer the class of the image.""" + image_embs = self.encoder.encode_image(pil_image) + logits = self.model.network(image_embs) + probabilities = torch.nn.functional.softmax(logits, dim=-1) + predicted_class = torch.argmax(probabilities, dim=-1).item() + return predicted_class + + def is_safe_file(self, filepath: str) -> bool: + """Check if the video file is safe.""" + video_data = read_video(filepath) + + # Sample frames at 2 FPS + sample_rate = 2 # frames per second + frame_interval = int(video_data.fps / sample_rate) + frame_numbers = list(range(0, int(video_data.fps * video_data.duration), frame_interval)) + + is_safe = True + frame_scores = [] + + for frame_number in frame_numbers: + try: + frame = video_data.frames[frame_number] + pil_image = Image.fromarray(frame) + predicted_class = self.__infer(pil_image) + class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown") + frame_scores.append({"frame_number": frame_number, "class": class_name}) + + # If any frame is not "Safe", mark the video as unsafe + if predicted_class != 0: + is_safe = False + break + + except Exception as e: + log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. 
Exception: {e}") + continue + + # Prepare data for JSON + video_data = { + "filepath": filepath, + "is_safe": is_safe, + "video_length": video_data.duration, + "fps": video_data.fps, + "frame_scores": frame_scores, + } + + log.info(f"Video {filepath} is {'SAFE' if is_safe else 'UNSAFE'}.") + log.debug(f"Video data: {json.dumps(video_data, indent=4)}") + return is_safe + + def is_safe_frames(self, frames: Iterable) -> bool: + """Check if the video frames are safe.""" + is_safe = True + frame_scores = [] + + for frame_number, frame in enumerate(frames): + try: + pil_image = Image.fromarray(frame) + predicted_class = self.__infer(pil_image) + class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Unknown") + frame_scores.append({"frame_number": frame_number, "class": class_name}) + + # If any frame is not "Safe", mark as not safe + if predicted_class != 0: + is_safe = False + break + + except Exception as e: + log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}") + continue + + video_data = { + "is_safe": is_safe, + "frame_scores": frame_scores, + } + + log.debug(f"Frames data: {json.dumps(video_data, indent=4)}") + return is_safe + + def is_safe(self, input: Union[str, Iterable]) -> Tuple[bool, str]: + if isinstance(input, str): + is_safe = self.is_safe_file(input) + return is_safe, "safe video detected" if is_safe else "unsafe video detected" + elif isinstance(input, Iterable): + is_safe = self.is_safe_frames(input) + return is_safe, "safe frames detected" if is_safe else "unsafe frames detected" + else: + raise ValueError(f"Input type {type(input)} not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the Video Content Safety Filter checkpoint folder", + ) + return parser.parse_args() + + +def main(args): + filepaths = get_video_filepaths(args.input_dir) + if not filepaths: + log.error(f"No video files found in directory: {args.input_dir}") + return + + video_filter = VideoContentSafetyFilter(checkpoint_dir=args.checkpoint_dir) + runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe") + + for filepath in filepaths: + with misc.timer("video content safety filter"): + _ = runner.run_safety_check(filepath) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6c7b45422501fec96bd1e711509ead5efa019a --- /dev/null +++ b/cosmos_predict1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from PIL import Image +from transformers import SiglipModel, SiglipProcessor + + +class SigLIPEncoder(torch.nn.Module): + def __init__( + self, + checkpoint_dir: str, + model_name: str = "google/siglip-so400m-patch14-384", + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.float32, + ) -> None: + super().__init__() + self.checkpoint_dir = checkpoint_dir + self.device = device + self.dtype = dtype + self.model = SiglipModel.from_pretrained(model_name, cache_dir=self.checkpoint_dir) + self.processor = SiglipProcessor.from_pretrained(model_name, cache_dir=self.checkpoint_dir) + self.model.to(self.device, dtype=self.dtype).eval() + + @torch.inference_mode() + def encode_image(self, input_img: Image.Image) -> torch.Tensor: + """Encode an image into a feature vector.""" + with torch.no_grad(): + inputs = self.processor(images=input_img, return_tensors="pt").to(self.device, dtype=self.dtype) + image_features = self.model.get_image_features(**inputs) + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features diff --git a/cosmos_predict1/auxiliary/t5_text_encoder.py b/cosmos_predict1/auxiliary/t5_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..53f34ba014e36c8225635abaea5acc2d2c95e008 --- /dev/null +++ b/cosmos_predict1/auxiliary/t5_text_encoder.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple, Union + +import torch +import transformers +from transformers import T5EncoderModel, T5TokenizerFast + +from cosmos_predict1.utils import log + +transformers.logging.set_verbosity_error() + + +class CosmosT5TextEncoder(torch.nn.Module): + """Handles T5 text encoding operations.""" + + def __init__(self, model_name: str = "google-t5/t5-11b", device: str = "cuda", cache_dir: str = "~/.cache"): + """Initializes the T5 tokenizer and encoder. + + Args: + model_name: The name of the T5 model to use. + device: The device to use for computations. + """ + super().__init__() + try: + self.tokenizer = T5TokenizerFast.from_pretrained(cache_dir, cache_dir=cache_dir) + self.text_encoder = T5EncoderModel.from_pretrained(cache_dir, cache_dir=cache_dir).to(device) + except Exception as e: + log.warning(f"Failed to load T5 model using cache_dir '{cache_dir}', falling back to default location: {e}") + self.tokenizer = T5TokenizerFast.from_pretrained(model_name) + self.text_encoder = T5EncoderModel.from_pretrained(model_name).to(device) + self.text_encoder.eval() + self.device = device + + @torch.inference_mode() + def encode_prompts( + self, prompts: Union[str, List[str]], max_length: int = 512 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encodes text prompts into hidden state representations using a T5 encoder. 
+ + This function tokenizes the input prompts, processes them through a T5 text encoder, + and returns the last hidden states. The encoded outputs beyond the actual sequence + length are zero-padded. All prompts in a batch are padded to max_length. + + Args: + prompts: Input text to encode. Can be a single string or a list of strings. + max_length: Maximum sequence length for tokenization and padding. Longer + sequences will be truncated. Defaults to 512. + return_mask: If True, returns the attention mask along with encoded text. + Defaults to False. + + Returns: + If return_mask is False: + torch.Tensor: Encoded text embeddings of shape (batch_size, max_length, hidden_size). + If return_mask is True: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Encoded text embeddings of shape (batch_size, max_length, hidden_size) + - Attention mask of shape (batch_size, max_length) as boolean tensor + + Raises: + ValueError: If the input prompts list is empty. + + Example: + >>> encoder = CosmosT5TextEncoder() + >>> prompts = ["Hello world", "Another example"] + >>> embeddings = encoder.encode_prompts(prompts, max_length=128) + """ + if isinstance(prompts, str): + prompts = [prompts] + + if not prompts: + raise ValueError("The input prompt list is empty.") + + batch_encoding = self.tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + input_ids = batch_encoding.input_ids.to(self.device) + attn_mask = batch_encoding.attention_mask.to(self.device) + + outputs = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask) + + encoded_text = outputs.last_hidden_state + lengths = attn_mask.sum(dim=1).cpu() + + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + return encoded_text, attn_mask diff --git a/cosmos_predict1/callbacks/every_n.py b/cosmos_predict1/callbacks/every_n.py new file mode 100644 index 0000000000000000000000000000000000000000..25cab309a58336867ed5fc58849e71db7611d0f3 --- /dev/null +++ b/cosmos_predict1/callbacks/every_n.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Optional + +import torch + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.callback import Callback +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class EveryN(Callback): + def __init__( + self, + every_n: Optional[int] = None, + step_size: int = 1, + barrier_after_run: bool = True, + run_at_start: bool = False, + ) -> None: + """Constructor for `EveryN`. + + Args: + every_n (int): Frequency with which callback is run during training. + step_size (int): Size of iteration step count. Default 1. 
+ barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts. + run_at_start (bool): Whether to run at the beginning of training. Default False. + """ + self.every_n = every_n + if self.every_n == 0: + log.warning( + f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped." + ) + + self.step_size = step_size + self.barrier_after_run = barrier_after_run + self.run_at_start = run_at_start + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training + if self.every_n != 0: + trainer = self.trainer + global_step = iteration // self.step_size + should_run = (iteration == 1 and self.run_at_start) or ( + global_step % self.every_n == 0 + ) # (self.every_n - 1) + if should_run: + log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}") + self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration) + log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}") + # add necessary barrier to avoid timeout + if self.barrier_after_run: + distributed.barrier() + + @abstractmethod + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + ... diff --git a/cosmos_predict1/callbacks/grad_clip.py b/cosmos_predict1/callbacks/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4f320b6f79e1e289117d8190b5f6df52cf64ae --- /dev/null +++ b/cosmos_predict1/callbacks/grad_clip.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
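A minimal sketch of subclassing the EveryN callback above; LogLossEveryN is a made-up name, and the surrounding trainer is assumed to register the callback as usual:

from cosmos_predict1.callbacks.every_n import EveryN
from cosmos_predict1.utils import log


class LogLossEveryN(EveryN):
    def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
        # Invoked only on steps where (iteration // step_size) is a multiple of every_n.
        log.info(f"iteration {iteration}: loss={loss.item():.4f}")


callback = LogLossEveryN(every_n=100, barrier_after_run=False)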
+ +from typing import List, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_predict1.utils import distributed +from cosmos_predict1.utils.callback import Callback + + +@torch.jit.script +def _fused_nan_to_num(params: List[torch.Tensor]): + for param in params: + torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param) + + +class GradClip(Callback): + def __init__( + self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False + ): + self.clip_norm = clip_norm + self.force_finite = force_finite + self.model_key = model_key + self.fsdp_enabled = fsdp_enabled + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + + # select sub-network if specified + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + + if self.force_finite: + params = [] + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + # check if FSDP is used + # total_norm + if isinstance(model, FSDP) and self.fsdp_enabled: + model.clip_grad_norm_(self.clip_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) diff --git a/cosmos_predict1/checkpointer/__init__.py b/cosmos_predict1/checkpointer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/checkpointer/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/checkpointer/base.py b/cosmos_predict1/checkpointer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4721905f50c45eae13dac833754303e9923c3a33 --- /dev/null +++ b/cosmos_predict1/checkpointer/base.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
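The two operations GradClip performs outside FSDP can be sketched on a plain module; this is a toy, standalone example of the same logic, not the callback wiring itself:

import torch
import torch.nn as nn

net = nn.Linear(8, 2)
net(torch.randn(4, 8)).sum().backward()

# force_finite=True: scrub NaN/Inf gradients in place before clipping.
for p in net.parameters():
    if p.grad is not None:
        torch.nan_to_num(p.grad, nan=0.0, posinf=0.0, neginf=0.0, out=p.grad)

# clip_norm=1.0: clip the global gradient norm, as in the non-FSDP branch above.
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0, foreach=True)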
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +from cosmos_predict1.utils import callback +from cosmos_predict1.utils.config import CheckpointConfig, JobConfig +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.model import Model + + +class AbstractCheckpointer(ABC): + """The checkpointer class. Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + self.config_checkpoint = config_checkpoint + # Set the callback functions. + self.callbacks = callbacks + + # Set checkpoint directories for local paths + self._local_dirname = os.path.join(config_job.path_local, "checkpoints") + + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path or None + self.load_training_state = config_checkpoint.load_training_state + self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state + self.save_thread = None + self.verbose = config_checkpoint.verbose + self.keys_not_to_resume = config_checkpoint.keys_not_to_resume + self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem + + @abstractmethod + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + pass + + @abstractmethod + def load( + self, + model: Model, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + grad_scaler: Optional[torch.amp.GradScaler] = None, + ) -> int: + pass + + @property + def save_bucket(self): + """Get the bucket name for saving checkpoints.""" + return None + + @property + def load_bucket(self): + """Get the bucket name for loading checkpoints.""" + return None + + @property + def save_dirname(self): + return self._local_dirname + + @property + def load_dirname(self): + return self._local_dirname + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. + """ + checkpoint_file = None + checkpoint_path = os.path.join(self.load_dirname, "latest_checkpoint.txt") + if easy_io.exists(checkpoint_path): + checkpoint_file = easy_io.load(checkpoint_path).strip() + + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + checkpoint_path = os.path.join(self.save_dirname, "latest_checkpoint.txt") + easy_io.dump(content, checkpoint_path) + + def _check_checkpoint_exists(self, checkpoint_path: str) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. 
+ """ + if not easy_io.exists(checkpoint_path): + raise FileNotFoundError(f"File not found: {checkpoint_path}") diff --git a/cosmos_predict1/checkpointer/ddp.py b/cosmos_predict1/checkpointer/ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..cee4cf147f6882731239d3925355de581d7e9f1f --- /dev/null +++ b/cosmos_predict1/checkpointer/ddp.py @@ -0,0 +1,437 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +from collections import namedtuple +from typing import Any, Dict, Optional, Set, Tuple, Union + +import torch +import torch.distributed +from megatron.core import parallel_state +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.checkpointer.base import AbstractCheckpointer +from cosmos_predict1.checkpointer.safe_broadcast import broadcast_object +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.model import Model + +StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) + + +class Checkpointer(AbstractCheckpointer): + """ + Checkpointer for DDP. + Note: This implementation only supports local filesystem. + """ + + KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] + KEYS_TO_POSTFIX = { + "model": "model", + "optim": "optim", + "scheduler": "scheduler", + "trainer": "", + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + ep_world_size = parallel_state.get_expert_model_parallel_world_size() + assert pp_world_size < 2, "Pipeline Parallelism (PP) is not tested yet." + assert ep_world_size < 2, "Expert Parallelism (EP) is not tested yet." + self.mp_world_size = parallel_state.get_model_parallel_group().size() + if self.mp_world_size > 1 and self.__class__ == Checkpointer: + raise NotImplementedError( + "Model Parallelism (MP) is enabled - " + "you should use TensorParallel Checkpointer instead of DDP Checkpointer." + ) + # DDP rank (with context parallelism considered) + self.rank_dp_w_cp = parallel_state.get_data_parallel_rank(with_context_parallel=True) + # Context parallelism rank + self.cp_rank = parallel_state.get_context_parallel_rank() + # Model parallelism rank (including Tensor+Pipeline+Expert Parallelisms) + self.mp_rank = parallel_state.get_model_parallel_group().rank() + # self.mp_rank = parallel_state.get_model_parallel_group(with_expert_parallel=ep_world_size > 1).rank() + if self.broadcast_via_filesystem: + log.info("Broadcasting checkpoint data via the local filesystem.") + if not self.strict_resume: + log.warning("Strict resume mode is off. 
Some model parameters may not be loaded.") + + # collect ranks of all model parallel groups + all_ranks = [None for _ in range(distributed.get_world_size())] + torch.distributed.all_gather_object( + all_ranks, get_process_group_ranks(parallel_state.get_model_parallel_group()) + ) + all_ranks = list(set(tuple(rank) if isinstance(rank, list) else rank for rank in all_ranks)) + for ranks in all_ranks: + group = torch.distributed.new_group(list(ranks), backend="gloo") + if distributed.get_rank() in ranks: + self.mp_gloo_pg = group + + self.print("Checkpointer Initialized.") + + def print(self, message: str): + """ + Print message to the console. Include the parallelism rank information when verbose is set to True. + """ + if self.verbose: + log.info( + f"[Parallelism Rank: DP-{self.rank_dp_w_cp}, TP-{self.mp_rank}, CP-{self.cp_rank}]: {message}", + rank0_only=False, + ) + else: + log.info(message, rank0_only=True) + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + del model + assert key in self.KEYS_TO_SAVE + post_fix = self.KEYS_TO_POSTFIX[key] + + if post_fix: + _ckpt_path = checkpoint_path.replace(".pt", f"_{post_fix}.pt") + else: + _ckpt_path = checkpoint_path + return _ckpt_path + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + **ignore_kwargs, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = self.format_checkpoint_filename(model, iteration) + state_dict = self.generate_save_state_dict(model, optimizer, scheduler, grad_scaler, iteration) + state_dict = self._map_state_dict_path_during_save(state_dict, checkpoint_file, model) + if state_dict: + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + def _map_state_dict_path_during_save(self, state_dict, checkpoint_file, model) -> dict[str, StateDictItemPath]: + new_dict = {} + for key, _state_dict in state_dict.items(): + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_file, model) + checkpoint_path = os.path.join(self.save_dirname, _ckpt_path) + new_dict[key] = StateDictItemPath(_state_dict, checkpoint_path) + return new_dict + + @misc.timer("checkpoint saving") + def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to disk, spawned with a child thread (in parallel with the training). + + Args: + state_dict (dict[str, StateDictItemPath]): The state dict of the model/optimizer/scheduler. 
+ checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). + """ + try: + for key, item in state_dict.items(): + self.print(f"Saving {key} to {item.save_path}") + try: + easy_io.dump( + item.state_dict, + item.save_path, + fast_backend=True, # optional for fast backend, cpu heavy + ) + self.print(f"Saved {key} to {item.save_path}") + except Exception as e: + self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") + raise # Re-raise the exception after logging + + # Synchronize only rank 0 of each model parallel group + if self.mp_world_size > 1: + torch.distributed.barrier(group=self.mp_gloo_pg) + + # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt + if self.mp_rank == 0 and self.rank_dp_w_cp == 0: + self._write_latest_checkpoint_file(checkpoint_file) + + if distributed.get_rank() == 0: # only rank 0 saves trained_data_record + if "trained_data_record" in state_dict["model"].state_dict: + self._write_trained_data_record( + checkpoint_file, state_dict["model"].state_dict["trained_data_record"] + ) + + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) + + def format_checkpoint_filename(self, model: Model, iteration: int) -> str: + """Generate the checkpoint file name. + + Args: + iteration (int): The current iteration number. + + Returns: + checkpoint_file (str): The checkpoint file name. + """ + del self, model + return f"iter_{iteration:09}.pt" + + @misc.timer("generate saving state dict") + def generate_save_state_dict( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> Optional[Dict[str, Any]]: + state_dict = {} + + if self.rank_dp_w_cp == 0: + trainer_state = dict( + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + model_state = model.state_dict() + optim_state = optimizer.state_dict() + scheduler_state = scheduler.state_dict() + self.callbacks.on_save_checkpoint(model, state_dict=trainer_state) + + trainer_state, model_state, optim_state, scheduler_state = misc.to( + [trainer_state, model_state, optim_state, scheduler_state], device="cpu" + ) + + state_dict = { + "model": model_state, + "optim": optim_state, + "scheduler": scheduler_state, + } + if distributed.get_rank() == 0: # only rank 0 saves trainer state + state_dict["trainer"] = trainer_state + return state_dict + return state_dict + + def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: + """ + Load state_dict and broadcast. + + The main steps are: + 1. Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + 2. Each rank loads its corresponding checkpoint from the local cache or receives it via broadcast. + + This approach ensures that each MP rank loads its specific part of the model, which is + crucial for Model Parallelism where different parts of the model are distributed across + multiple GPUs. + + When using Model Parallelism (e.g., Tensor Parallelism), the `broadcast_via_filesystem` option can + be set to True. This allows each rank to load its specific checkpoint from the local filesystem + instead of receiving it via network broadcast, which could be more efficient in some cases. 
+ + For standard DDP without TP, `broadcast_via_filesystem` should remain False (default). + + Args: + checkpoint_path (str): The base path of the checkpoint. + model (Model): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + # Step 1: Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + if self.rank_dp_w_cp == 0: + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + # If the local checkpoint exists, we can directly load it + self.print(f"Checkpoint is already in local cache: {local_cache_path}. Loading...") + _state_dict = easy_io.load(local_cache_path, fast_backend=True) + else: + self.print(f"Downloading checkpoint from: {_ckpt_path}") + _state_dict = easy_io.load(_ckpt_path, fast_backend=True) + if self.broadcast_via_filesystem: + # Save the checkpoint to the local filesystem + easy_io.dump(_state_dict, local_cache_path, fast_backend=True) + state_dict[key] = _state_dict + # Ensure all ranks wait for the download to complete + distributed.barrier() + + # Step 2: Broadcast checkpoint data + log.info( + "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", + rank0_only=True, + ) + for key in sorted_resume_keys: + if self.broadcast_via_filesystem: + # Load the checkpoint from the local filesystem for other ranks + if self.rank_dp_w_cp != 0: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + self.print(f"Loading checkpoint from: {local_cache_path}") + state_dict[key] = easy_io.load(local_cache_path, fast_backend=True) + else: + # Broadcast the checkpoint to all GPUs of the current DDP rank + group: ProcessGroup = parallel_state.get_data_parallel_group(with_context_parallel=True) + min_rank = min(get_process_group_ranks(group)) + + _state_dict = broadcast_object( + state_dict[key] if self.rank_dp_w_cp == 0 else None, + min_rank, + group=group, + device=torch.device(torch.cuda.current_device()), + ) + if self.rank_dp_w_cp == 0: + self.print(f'Broadcasted checkpoint["{key}"] to all other ranks in the same DDP group.') + else: + state_dict[key] = _state_dict + self.print(f'Received checkpoint["{key}"] from source rank {min_rank}.') + + return state_dict + + def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: + latest_checkpoint_file = self._read_latest_checkpoint_file() + + resume_keys = [] + + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) + resume_keys.extend(self.KEYS_TO_SAVE) + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. 
+ checkpoint_path = self.load_path + if self.load_training_state: + resume_keys.extend(self.KEYS_TO_SAVE) + else: + resume_keys.append("model") + if self.only_load_scheduler_state: + resume_keys.append("scheduler") + else: + checkpoint_path = None + if len(self.keys_not_to_resume) > 0: + for key in self.keys_not_to_resume: + assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" + resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] + return set(resume_keys), checkpoint_path + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + resume_keys, checkpoint_path = self.keys_to_resume_during_load() + + iteration = 0 + + # Load checkpoint. 
+ if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + state_dict = self.load_broadcast_state_dict(checkpoint_path, model, set(resume_keys)) + + if "trainer" in state_dict: + trainer_state = state_dict["trainer"] + log.critical(state_dict.keys(), rank0_only=False) + log.critical(trainer_state, rank0_only=False) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(trainer_state["grad_scaler"]) + self.callbacks.on_load_checkpoint(model, state_dict=trainer_state) + iteration = trainer_state["iteration"] + if "optim" in state_dict: + assert optimizer + optimizer_state = state_dict["optim"] + log.info("- Loading the optimizer...") + optimizer.load_state_dict(optimizer_state) + if "scheduler" in state_dict: + assert scheduler + scheduler_state = state_dict["scheduler"] + log.info("- Loading the scheduler...") + scheduler.load_state_dict(scheduler_state) + scheduler.last_epoch = iteration + if "model" in state_dict: + model_state = state_dict["model"] + log.info("- Loading the model...") + # model.load_state_dict(model_state) + if self.strict_resume: + log.info("\t Strict resume mode is on.") + else: + log.info("\t Strict resume mode is off.") + model_load_info = model.load_state_dict(model_state, strict=self.strict_resume) + log.info(f"\t {model_load_info}") + self.print(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _write_trained_data_record(self, checkpoint_file: str, trained_data_record: dict[str, int]) -> None: + """Write json file to save number of seen samples and number of iterations. + + Args: + checkpoint_file (str): iteration number for the saved checkpoint + trained_data_record (dict[str, int]): example {"image": 0, "video": 0, "iteration": 0}. + """ + # filename: iter_xxxxxxxxx_trained_data_record.json + checkpoint_path = os.path.join( + self.save_dirname, f"{checkpoint_file.replace('.pt', '')}_trained_data_record.json" + ) + easy_io.dump(trained_data_record, checkpoint_path) diff --git a/cosmos_predict1/checkpointer/peft_checkpointer.py b/cosmos_predict1/checkpointer/peft_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7a3341042d35d7da6a7e3c4955d400ce73dd6a --- /dev/null +++ b/cosmos_predict1/checkpointer/peft_checkpointer.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Set + +import torch + +from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.model import Model + + +class Checkpointer(DDPCheckpointer): + """ + Checkpointer class for PEFT in distributed training. 
This class is similar to the DDP checkpointer, + with the exception that the `broadcast_via_filesystem` functionality is not supported, and it supports + loading pre-trained model without any postfix. + + Note: + - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.broadcast_via_filesystem: + raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.") + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + """ + Overwrite the `add_type_postfix_to_checkpoint_path` function of the base class (DDP checkpointer) + to load pre-trained model without any postfix. + """ + checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + checkpoint_path = checkpoint_path.replace("model_model.pt", "model.pt") + return checkpoint_path + + def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: + """ + Load state_dict and broadcast for PEFT checkpointer. + + This function is identical to the `load_broadcast_state_dict` function of the base class (DDP checkpointer), + with the exception that the `broadcast_via_filesystem` functionality is not supported. + + Args: + checkpoint_path (str): The base path of the checkpoint. + model (Model): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + # Step 1: Download checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + if self.rank_dp_w_cp == 0: + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + # If the local checkpoint exists, we can directly load it + self.print(f"Checkpoint is already in local cache: {local_cache_path}. 
Loading...") + _state_dict = torch.load( + local_cache_path, map_location=lambda storage, loc: storage, weights_only=False + ) + else: + # Pre-trained model is not in local cache, so we need to load it from the checkpoint path + self.print(f"Loading checkpoint from: {_ckpt_path}") + _state_dict = torch.load(_ckpt_path, map_location=lambda storage, loc: storage, weights_only=False) + state_dict[key] = _state_dict + + # Ensure all ranks wait for the download to complete + distributed.barrier() + + # Step 2: Broadcast checkpoint data + log.info( + "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", + rank0_only=True, + ) + for key in sorted_resume_keys: + if self.broadcast_via_filesystem: + # Load the checkpoint from the local filesystem for other ranks + if self.rank_dp_w_cp != 0: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + self.print(f"Loading checkpoint from: {local_cache_path}") + state_dict[key] = torch.load( + local_cache_path, map_location=lambda storage, loc: storage, weights_only=False + ) + else: + self.print(f"Loading checkpoint from: {_ckpt_path}") + state_dict[key] = torch.load( + _ckpt_path, map_location=lambda storage, loc: storage, weights_only=False + ) + + else: + raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.") + + return state_dict diff --git a/cosmos_predict1/checkpointer/safe_broadcast.py b/cosmos_predict1/checkpointer/safe_broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..f914299c97f297cf43a5919b0bc8130686afa7c1 --- /dev/null +++ b/cosmos_predict1/checkpointer/safe_broadcast.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import io +import pickle +from typing import Any + +import torch +import torch.distributed as dist + + +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/zero_redundancy_optimizer.py#L29 +def broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + device: torch.device = torch.device("cpu"), +) -> Any: + r""" + Broadcasts an object to the given group. + + It will be sending the object if called from the source rank and receiving + the object otherwise. + + Arguments: + obj: object to broadcast; only used if called on the source rank. + src_rank (int): source rank. + group (``ProcessGroup``, optional): group used for the broadcast + (default: ``dist.group.WORLD``). + device (``torch.device``, optional): device to send from or receive + to (default: ``torch.device("cpu")``). + + Returns: + The broadcasted object. 
+ """ + if dist.get_rank() == src_rank: + # Send the object + buffer = io.BytesIO() + torch.save(obj, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.LongTensor([len(data)]).to(device) + data_send_tensor = torch.ByteTensor(data).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Receive the object + length_tensor = torch.LongTensor([0]).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device) + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = torch.load(buffer, map_location=device, weights_only=False) + return obj + + +def _recursive_copy_to_device( + value: Any, + non_blocking: bool, + device: torch.device, +) -> Any: + r""" + Recursively searches lists, tuples, dicts and copies tensors to device if possible. + + Non-tensor values are passed as-is in the result. + + .. note: These are all copies, so if there are two objects that reference + the same object, then after this call, there will be two different objects + referenced on the device. + """ + if isinstance(value, torch.Tensor): + return value.to(device, non_blocking=non_blocking) + + if isinstance(value, (list, tuple)): + values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value] + return values if isinstance(value, list) else tuple(values) + + if isinstance(value, collections.abc.Mapping): + return { + key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items() + } + + return value diff --git a/cosmos_predict1/checkpointer/tp.py b/cosmos_predict1/checkpointer/tp.py new file mode 100644 index 0000000000000000000000000000000000000000..b97231a66aa13f5b26627fd7c302cbbb721492b0 --- /dev/null +++ b/cosmos_predict1/checkpointer/tp.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer +from cosmos_predict1.utils.model import Model + + +class Checkpointer(DDPCheckpointer): + """ + Checkpointer class for Tensor Parallelism (TP) in distributed training. + + This implementation supports the combination of Tensor Parallelism (TP) and Data Parallel Processing (DDP), with optional Context Parallelism (CP). + + Note: + - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer. + - In principle, this implementation is also compatible with Pipeline Parallelism (PP) and Expert Parallelism (EP), which are other forms of model parallelism. However, PP and EP have not been tested yet. 
+ """ + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + """ + Overwrite the `add_type_postfix_to_checkpoint_path` function of the base class (DDP checkpointer) + to append the TP-rank postfix to the checkpoint path. + """ + checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + if key == "trainer": + return checkpoint_path + else: + checkpoint_path = checkpoint_path.replace(".pt", f"_mp_{self.mp_rank}.pt") + + return checkpoint_path diff --git a/cosmos_predict1/diffusion/checkpointers/ema_fsdp_checkpointer.py b/cosmos_predict1/diffusion/checkpointers/ema_fsdp_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea13c45291ab346fc482725bb1b15b58f1d9618 --- /dev/null +++ b/cosmos_predict1/diffusion/checkpointers/ema_fsdp_checkpointer.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import attrs + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config import CheckpointConfig as BaseCheckpointConfig +from cosmos_predict1.utils.config import make_freezable +from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer as BaseFSDPCheckpointer + + +@make_freezable +@attrs.define(slots=False) +class CheckpointConfig(BaseCheckpointConfig): + load_ema_to_reg: bool = False + + +class FSDPCheckpointer(BaseFSDPCheckpointer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not isinstance(self.config_checkpoint, CheckpointConfig): + warnings.warn( + "The 'config_checkpoint' is not an instance of 'CheckpointConfig'. " + "This behavior is deprecated and will not be supported in future versions. " + "Please update 'config_checkpoint' to be of type 'CheckpointConfig'.", + DeprecationWarning, + ) + + self.load_ema_to_reg = False + else: + self.load_ema_to_reg = self.config_checkpoint.load_ema_to_reg + + log.critical(f"load_ema_to_reg: {self.load_ema_to_reg}", rank0_only=False) + + def load_model_during_init(self, model, is_ema: bool = False, ema_id: int = 0): + if self.load_ema_to_reg and is_ema is False: + is_ema = True + ema_id = 0 + log.critical("Loading EMA model to regular model during initialization.", rank0_only=False) + super().load_model_during_init(model, is_ema, ema_id) diff --git a/cosmos_predict1/diffusion/conditioner.py b/cosmos_predict1/diffusion/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6deb003d806b464c7921b1b78056b850a6045d --- /dev/null +++ b/cosmos_predict1/diffusion/conditioner.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import instantiate + + +class BaseConditionEntry(nn.Module): + def __init__(self): + super().__init__() + + self._dropout_rate = None + self._input_key = None + self._return_dict = False + + @property + def dropout_rate(self) -> Union[float, torch.Tensor]: + return self._dropout_rate + + @property + def input_key(self) -> str: + return self._input_key + + @property + def is_return_dict(self) -> bool: + return self._return_dict + + @dropout_rate.setter + def dropout_rate(self, value: Union[float, torch.Tensor]): + self._dropout_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_return_dict.setter + def is_return_dict(self, value: bool): + self._return_dict = value + + @dropout_rate.deleter + def dropout_rate(self): + del self._dropout_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + @is_return_dict.deleter + def is_return_dict(self): + del self._return_dict + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + bernoulli = torch.bernoulli((1.0 - dropout_rate) * torch.ones(len(in_tensor))).type_as(in_tensor) + bernoulli_expand = bernoulli.view((-1,) + (1,) * (in_tensor.dim() - 1)) + return bernoulli_expand * in_tensor + + def summary(self) -> str: + pass + + +class DataType(Enum): + IMAGE = "image" + VIDEO = "video" + + +class TextAttr(BaseConditionEntry): + def __init__(self): + super().__init__() + + def forward(self, token: torch.Tensor, mask: torch.Tensor): + return {"crossattn_emb": token, "crossattn_mask": mask} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + if key is not None and "mask" in key: + return in_tensor + return super().random_dropout_input(in_tensor, dropout_rate, key) + + +@dataclass +class BaseVideoCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + data_type: DataType = DataType.VIDEO + padding_mask: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + image_size: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + frame_repeat: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +@dataclass +class VideoExtendCondition(BaseVideoCondition): + video_cond_bool: Optional[torch.Tensor] = None # whether or not it conditioned on video + gt_latent: Optional[torch.Tensor] = None + condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region + + # 
condition_video_input_mask will concat to the input of network, along channel dim; + # Will be concat with the input tensor + condition_video_input_mask: Optional[torch.Tensor] = None + # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + condition_video_augment_sigma: Optional[torch.Tensor] = None + condition_video_pose: Optional[torch.Tensor] = None + + +class GeneralConditioner(nn.Module, ABC): + """ + An abstract module designed to handle various embedding models with conditional and + unconditional configurations. This abstract base class initializes and manages a collection + of embedders that can dynamically adjust their dropout rates based on conditioning. + + Attributes: + KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation. + embedders (nn.ModuleDict): A dictionary containing all embedded models initialized and + configured based on the provided configurations. + + Parameters: + emb_models (Union[List, Any]): A dictionary where keys are embedder names and values + are configurations for initializing the embedders. + + """ + + KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1} + + def __init__(self, **emb_models: Union[List, Any]): + super().__init__() + self.embedders = nn.ModuleDict() + for n, (emb_name, embconfig) in enumerate(emb_models.items()): + embedder = instantiate(embconfig.obj) + assert isinstance( + embedder, BaseConditionEntry + ), f"embedder model {embedder.__class__.__name__} has to inherit from BaseConditionEntry" + embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0) + + if hasattr(embconfig, "input_key"): + embedder.input_key = embconfig.input_key + elif hasattr(embconfig, "input_keys"): + embedder.input_keys = embconfig.input_keys + else: + raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") + + log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}") + self.embedders[emb_name] = embedder + + @abstractmethod + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> Any: + """Should be implemented in subclasses to handle conditon datatype""" + raise NotImplementedError + + def _forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> Dict: + """ + Processes the input batch through all configured embedders, applying conditional dropout rates if specified. + Output tensors for each key are concatenated along the dimensions specified in KEY2DIM. + + Parameters: + batch (Dict): The input data batch to process. + override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates + per embedder key. + + Returns: + Dict: A dictionary of output tensors concatenated by specified dimensions. + + Note: + In case the network code is sensitive to the order of concatenation, you can either control the order via \ + config file or make sure the embedders return a unique key for each output. 
+ """ + output = defaultdict(list) + if override_dropout_rate is None: + override_dropout_rate = {} + + # make sure emb_name in override_dropout_rate is valid + for emb_name in override_dropout_rate.keys(): + assert emb_name in self.embedders, f"invalid name found {emb_name}" + + for emb_name, embedder in self.embedders.items(): + with torch.no_grad(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + emb_out = embedder( + embedder.random_dropout_input( + batch[embedder.input_key], override_dropout_rate.get(emb_name, None) + ) + ) + elif hasattr(embedder, "input_keys"): + emb_out = embedder( + *[ + embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k) + for k in embedder.input_keys + ] + ) + for k, v in emb_out.items(): + output[k].append(v) + # Concatenate the outputs + return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()} + + def get_condition_uncondition( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Processes the provided data batch to generate conditioned and unconditioned outputs. + + This method manipulates dropout rates to simulate two scenarios: + 1. All conditions applied (conditioned) + 2. Conditions removed/reduced to minimum (unconditioned) + + This method sets dropout rates to zero for the conditioned scenario to fully apply + embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is + insignificant) to minimize embedder influences. + + Parameters: + data_batch (Dict): Input data batch containing all necessary information for + embedding processing. + + Returns: + Tuple[Any, Any]: A tuple containing: + - Outputs with all embedders fully applied (conditioned) + - Outputs with embedders minimized/not applied (unconditioned) + """ + cond_dropout_rates, dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + cond_dropout_rates[emb_name] = 0.0 + dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates) + return condition, un_condition + + def get_condition_with_negative_prompt( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Similar functionality as get_condition_uncondition + But use negative prompts for unconditon + """ + cond_dropout_rates, uncond_dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + cond_dropout_rates[emb_name] = 0.0 + if isinstance(embedder, TextAttr): + uncond_dropout_rates[emb_name] = 0.0 + else: + uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + data_batch_neg_prompt = copy.deepcopy(data_batch) + if "neg_t5_text_embeddings" in data_batch_neg_prompt: + if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor): + data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"] + data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"] + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates) + + return condition, un_condition + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: 
getattr(self, f.name) for f in fields(self)} + + +class VideoConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseVideoCondition: + output = super()._forward(batch, override_dropout_rate) + return BaseVideoCondition(**output) + + +class VideoExtendConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoExtendCondition(**output) diff --git a/cosmos_predict1/diffusion/config/__init__.py b/cosmos_predict1/diffusion/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/config/base/__init__.py b/cosmos_predict1/diffusion/config/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/config/base/conditioner.py b/cosmos_predict1/diffusion/config/base/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..43e09aa039d2a71e30a4285970bcbe658e19df4d --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/conditioner.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
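# A minimal sketch of the per-sample dropout implemented by
# BaseConditionEntry.random_dropout_input above: a Bernoulli keep-mask is drawn
# once per batch element and broadcast over all remaining dimensions, so a
# dropped sample has its conditioning zeroed out entirely. Shapes below are
# illustrative only.
import torch

def random_dropout_input(in_tensor: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    keep = torch.bernoulli((1.0 - dropout_rate) * torch.ones(len(in_tensor))).type_as(in_tensor)
    return keep.view((-1,) + (1,) * (in_tensor.dim() - 1)) * in_tensor

emb = torch.randn(4, 512, 1024)                    # (B, L, D) text embeddings
out = random_dropout_input(emb, dropout_rate=0.2)  # each sample is kept intact or zeroed

# get_condition_uncondition reuses this mechanism: it overrides the rate to 0.0
# for the conditioned pass and to 1.0 (for embedders trained with dropout) for
# the unconditioned pass, producing the pair used by classifier-free guidance.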
+ +from typing import Dict, List, Optional + +import attrs +import torch + +from cosmos_predict1.diffusion.conditioner import BaseConditionEntry, TextAttr, VideoConditioner, VideoExtendConditioner +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class TextConfig: + obj: LazyDict = L(TextAttr)() # No arguments + dropout_rate: float = 0.2 + input_keys: List[str] = attrs.field(factory=lambda: ["t5_text_embeddings", "t5_text_mask"]) + + +class BooleanFlag(BaseConditionEntry): + def __init__(self, output_key: Optional[str] = None): + super().__init__() + self.output_key = output_key + + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + del args, kwargs + key = self.output_key if self.output_key else self.input_key + return {key: self.flag} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device) + return in_tensor + + +class ReMapkey(BaseConditionEntry): + def __init__(self, output_key: Optional[str] = None, dtype: Optional[str] = None): + super().__init__() + self.output_key = output_key + self.dtype = { + None: None, + "float": torch.float32, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float16": torch.float16, + "int": torch.int32, + "long": torch.int64, + }[dtype] + + def forward(self, element: torch.Tensor) -> Dict[str, torch.Tensor]: + key = self.output_key if self.output_key else self.input_key + if isinstance(element, torch.Tensor): + element = element.to(dtype=self.dtype) + return {key: element} + + +class FrameRepeatAttr(BaseConditionEntry): + def __init__(self): + super().__init__() + + def forward(self, frame_repeat: torch.Tensor) -> Dict[str, torch.Tensor]: + return { + "frame_repeat": frame_repeat / 10.0, + } + + +@attrs.define(slots=False) +class FPSConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `fps`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="fps", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "fps" + + +@attrs.define(slots=False) +class PaddingMaskConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `padding_mask`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="padding_mask", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "padding_mask" + + +@attrs.define(slots=False) +class ImageSizeConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `image_size`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="image_size", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "image_size" + + +@attrs.define(slots=False) +class NumFramesConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `num_frames`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="num_frames", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "num_frames" + + +@attrs.define(slots=False) +class FrameRepeatConfig: + """ + Remap and process key from the input dictionary to the output dictionary. For `frame_repeat`. 
+ """ + + obj: LazyDict = L(FrameRepeatAttr)() + dropout_rate: float = 0.0 + input_key: str = "frame_repeat" + + +@attrs.define(slots=False) +class VideoCondBoolConfig: + obj: LazyDict = L(BooleanFlag)(output_key="video_cond_bool") + dropout_rate: float = 0.2 + input_key: str = "fps" # This is a placeholder, we never use this value + # Config below are for long video generation only + compute_loss_for_condition_region: bool = False # Compute loss for condition region + + # How to sample condition region during training. "first_random_n" set the first n frames to be condition region, n is random, "random" set the condition region to be random, + condition_location: str = "first_random_n" + random_conditon_rate: float = 0.5 # The rate to sample the condition region randomly + first_random_n_num_condition_t_max: int = 4 # The maximum number of frames to sample as condition region, used when condition_location is "first_random_n" + first_random_n_num_condition_t_min: int = 0 # The minimum number of frames to sample as condition region, used when condition_location is "first_random_n" + + # How to dropout value of the conditional input frames + cfg_unconditional_type: str = "zero_condition_region_condition_mask" # Unconditional type. "zero_condition_region_condition_mask" set the input to zero for condition region, "noise_x_condition_region" set the input to x_t, same as the base model + + # How to corrupt the condition region + apply_corruption_to_condition_region: str = "noise_with_sigma" # Apply corruption to condition region, option: "gaussian_blur", "noise_with_sigma", "clean" (inference), "noise_with_sigma_fixed" (inference) + # Inference only option: list of sigma value for the corruption at different chunk id, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + apply_corruption_to_condition_region_sigma_value: list[float] = [0.001, 0.2] + [ + 0.5 + ] * 10 # Sigma value for the corruption, used when apply_corruption_to_condition_region is "noise_with_sigma_fixed" + + # Add augment_sigma condition to the network + condition_on_augment_sigma: bool = False + # The following arguments is to match with previous implementation where we use train sde to sample augment sigma (with adjust video noise turn on) + augment_sigma_sample_p_mean: float = 0.0 # Mean of the augment sigma + augment_sigma_sample_p_std: float = 1.0 # Std of the augment sigma + augment_sigma_sample_multiplier: float = 4.0 # Multipler of augment sigma + + # Add pose condition to the network + add_pose_condition: bool = False + + # Sample PPP... from IPPP... sequence + sample_tokens_start_from_p_or_i: bool = False + + # Normalize the input condition latent + normalize_condition_latent: bool = False + + +@attrs.define(slots=False) +class LatentConditionConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `latent condition`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="latent_condition", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "latent_condition" + + +@attrs.define(slots=False) +class LatentConditionSigmaConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `latent condition`. 
+ """ + + obj: LazyDict = L(ReMapkey)(output_key="latent_condition_sigma", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "latent_condition_sigma" + + +BaseVideoConditionerConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), +) + +VideoConditionerFpsSizePaddingConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), +) + +VideoExtendConditionerConfig: LazyDict = L(VideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), +) + +VideoConditionerFpsSizePaddingFrameRepeatConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + frame_repeat=FrameRepeatConfig(), +) + +VideoExtendConditionerFrameRepeatConfig: LazyDict = L(VideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), + frame_repeat=FrameRepeatConfig(), +) diff --git a/cosmos_predict1/diffusion/config/base/model.py b/cosmos_predict1/diffusion/config/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..819ddb4a3b5ea63887ba3697b085b70fee319874 --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/model.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional + +import attrs + +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class DefaultModelConfig: + tokenizer: LazyDict = None + conditioner: LazyDict = None + net: LazyDict = None + sigma_data: float = 0.5 + precision: str = "bfloat16" + input_data_key: str = "video" # key to fetch input data from data_batch + latent_shape: List[int] = [16, 24, 44, 80] # 24 corresponig to 136 frames + input_image_key: str = "images_1024" + adjust_video_noise: bool = False # Added field with default value + context_parallel_size: int = 1 # Added field with default value + # `num_latents_to_drop` is a flag that helps satisfy (1I,N*P,1I) latents setup. + # Since the tokenizer is causal and has the `T+1` input frames setup, it's + # challenging to encode arbitrary number of frames. To circumvent this, + # we sample as many frames, run the tokenizer twice, and discard the last + # chunk's P-latents, ensuring the requirement: I-latents for the input frames + # and P-latent for the-to-be-predicted in-between frames. + # By default, this flag does not have any effect. 
+ num_latents_to_drop: int = 0 # number of P-latents to discard after encoding + + sde: Optional[Dict] = None + vae: Optional[Dict] = None # Add this line to include the vae field + peft_control: LazyDict | None = None + frame_buffer_max: Optional[int] = 1 + + +@attrs.define(slots=False) +class LatentDiffusionDecoderModelConfig(DefaultModelConfig): + tokenizer_corruptor: LazyDict = None + latent_corruptor: LazyDict = None + pixel_corruptor: LazyDict = None + diffusion_decoder_cond_sigma_low: float = None + diffusion_decoder_cond_sigma_high: float = None + diffusion_decoder_corrupt_prob: float = None + condition_on_tokenizer_corruptor_token: bool = False + + +@attrs.define(slots=False) +class MultiviewModelConfig(DefaultModelConfig): + n_views: int = 4 diff --git a/cosmos_predict1/diffusion/config/base/net.py b/cosmos_predict1/diffusion/config/base/net.py new file mode 100644 index 0000000000000000000000000000000000000000..6272aa996abc1cf1b394dc1f9ac1b8d2555ca554 --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/net.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.diffusion.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +FADITV2Config: LazyDict = L(GeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, +) + + +FADITV2_14B_Config = copy.deepcopy(FADITV2Config) +FADITV2_14B_Config.model_channels = 5120 +FADITV2_14B_Config.num_heads = 40 +FADITV2_14B_Config.num_blocks = 36 + + +FADITV2_Multiview_Config: LazyDict = L(MultiviewGeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, + n_views=6, + view_condition_dim=6, + add_repeat_frame_embedding=True, + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=1.0, +) diff --git a/cosmos_predict1/diffusion/config/base/tokenizer.py b/cosmos_predict1/diffusion/config/base/tokenizer.py new file mode 100644 index 
0000000000000000000000000000000000000000..eb73d34ea58994feb9738d15b8738d495cea0b8a --- /dev/null +++ b/cosmos_predict1/diffusion/config/base/tokenizer.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import omegaconf + +from cosmos_predict1.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L + +TOKENIZER_OPTIONS = {} + + +def tokenizer_register(key): + def decorator(func): + TOKENIZER_OPTIONS[key] = func + return func + + return decorator + + +@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") +def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: int) -> omegaconf.dictconfig.DictConfig: + assert resolution in ["720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(JointImageVideoSharedJITTokenizer)( + video_vae=L(VideoJITTokenizer)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ), + image_vae=L(JITVAE)( + name="cosmos_predict1_tokenizer", + latent_ch=16, + is_image=False, + is_bf16=True, + ), + name="cosmos_predict1_tokenizer", + latent_ch=16, + ) diff --git a/cosmos_predict1/diffusion/config/config.py b/cosmos_predict1/diffusion/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..efb24aceb14b94ec6d47d8b0e539a56d3e957679 --- /dev/null +++ b/cosmos_predict1/diffusion/config/config.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
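# A minimal sketch of the decorator-based registry defined in tokenizer.py
# above: tokenizer_register stores each factory in TOKENIZER_OPTIONS under its
# key, so callers can look a factory up by name and invoke it with its
# arguments. The "identity" entry below is a toy placeholder, not a real
# tokenizer config.
OPTIONS = {}

def register(key):
    def decorator(func):
        OPTIONS[key] = func
        return func
    return decorator

@register("identity")
def get_identity_tokenizer(resolution: str, chunk_duration: int) -> dict:
    # Stands in for the LazyCall config the real factory returns.
    return {"resolution": resolution, "chunk_duration": chunk_duration}

cfg = OPTIONS["identity"](resolution="720", chunk_duration=121)
# The real code path has the same shape:
#   TOKENIZER_OPTIONS["cosmos_diffusion_tokenizer_comp8x8x8"](resolution="720", chunk_duration=121)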
+ +from typing import Any, List + +import attrs + +from cosmos_predict1.diffusion.config.base.model import DefaultModelConfig +from cosmos_predict1.diffusion.config.registry import register_configs +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"net": None}, + {"conditioner": "add_fps_image_size_padding_mask"}, + {"tokenizer": "tokenizer"}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config( + model=DefaultModelConfig(), + ) + + # Specifying values through instances of attrs + c.job.project = "cosmos_diffusion" + c.job.group = "inference" + + # Call this function to register config groups for advanced overriding. + register_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True) + return c diff --git a/cosmos_predict1/diffusion/config/inference/__init__.py b/cosmos_predict1/diffusion/config/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-gen3c.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-gen3c.py new file mode 100644 index 0000000000000000000000000000000000000000..e48c8c133494af1304d83ce9b9692025fc3b1b2a --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-gen3c.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
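# A self-contained sketch of the Hydra composition pattern that config.py above
# relies on: config groups are registered in the ConfigStore and a defaults
# list selects (or overrides) one node per group. Group, package, and node
# names here are toy placeholders, not the real cosmos config groups.
from dataclasses import dataclass, field
from typing import Any, List

from hydra import compose, initialize
from hydra.core.config_store import ConfigStore

@dataclass
class RootConfig:
    defaults: List[Any] = field(default_factory=lambda: ["_self_", {"net": "small"}])

cs = ConfigStore.instance()
cs.store(name="root", node=RootConfig)
cs.store(group="net", package="model_net", name="small", node={"channels": 64})
cs.store(group="net", package="model_net", name="big", node={"channels": 512})

with initialize(version_base=None, config_path=None):
    cfg = compose(config_name="root", overrides=["net=big"])
    assert cfg.model_net.channels == 512  # the override replaced the default group node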
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +GEN3C_Cosmos_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + in_channels=16 + 16 * 4 + 1 # 16: video_latent, 16 * 4: (warped_frames + warped_frames_mask) * buffer 2, 1: mask + ), + frame_buffer_max=2, + ), + job=dict(group="Gen3c", name="GEN3C_Cosmos_7B"), + ) +) + +cs = ConfigStore.instance() +for _item in [ + GEN3C_Cosmos_7B, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..57eda7cd4d5dea477b7ac7175cebf9b96b1f2133 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
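# A worked breakdown of the in_channels arithmetic in GEN3C_Cosmos_7B above,
# following its inline comment: 16 video-latent channels, warped frames plus a
# warped-frame mask (16 + 16 channels) for each of the frame_buffer_max=2
# buffers, and 1 extra mask channel. Variable names are illustrative only.
latent_channels = 16
warp_channels_per_buffer = 16 + 16     # warped_frames + warped_frames_mask
frame_buffer_max = 2
extra_mask_channels = 1

in_channels = latent_channels + warp_channels_per_buffer * frame_buffer_max + extra_mask_channels
assert in_channels == 16 + 16 * 4 + 1 == 81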
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Text2World_7B_Multiview: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + {"override /net": "faditv2_multiview_7b"}, + {"override /conditioner": "add_fps_image_size_padding_mask_frame_repeat"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Text2World_7B_Multiview", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + tokenizer=dict( + video_vae=dict( + pixel_chunk_duration=57, + ) + ), + ), + ) +) + + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_Predict1_Text2World_7B_Multiview["job"]["name"], + node=Cosmos_Predict1_Text2World_7B_Multiview, +) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world.py new file mode 100644 index 0000000000000000000000000000000000000000..13d709115df426479da991e3d827e1bb8c434fc7 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Text2World_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Text2World_7B", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=dict( + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + ), + ), + ) +) + +Cosmos_Predict1_Text2World_14B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Text2World_14B", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=dict( + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Text2World_14B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_14B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_14B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb", + ), + model=dict( + latent_shape=[ # 384x384 resolution + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=121, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb", + ), + model=dict( + latent_shape=[ # 384x384 resolution + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=33, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb", + ), + model=dict( + latent_shape=[ # 384x384 resolution + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=17, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Text2World_7B_Post_trained_lora: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B_Post_trained", + ], + job=dict( + 
name="Cosmos_Predict1_Text2World_7B_Post_trained_lora", + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=27, rank=8, scale=1), + ), + ) +) + +cs = ConfigStore.instance() + +for _item in [ + Cosmos_Predict1_Text2World_7B, + Cosmos_Predict1_Text2World_14B, + Cosmos_Predict1_Text2World_7B_Post_trained, + Cosmos_Predict1_Text2World_14B_Post_trained, + Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb, + Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb, + Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb, + Cosmos_Predict1_Text2World_7B_Post_trained_lora, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6cfdcee8b3d343a8897fd99836370b7c8a7dc4 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned_multiview import MultiviewVideoExtendGeneralDIT +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Video2World_7B_Multiview: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Text2World_7B_Multiview", + {"override /conditioner": "video_cond_frame_repeat"}, + "_self_", + ], + job=dict( + group="Text2World", + name="Cosmos_Predict1_Video2World_7B_Multiview", + ), + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + net=L(MultiviewVideoExtendGeneralDIT)( + n_views=6, + view_condition_dim=6, + add_repeat_frame_embedding=True, + ), + conditioner=dict(video_cond_bool=dict()), + ), + ) +) + + +cs = ConfigStore.instance() +cs.store( + group="experiment", + package="_global_", + name=Cosmos_Predict1_Video2World_7B_Multiview["job"]["name"], + node=Cosmos_Predict1_Video2World_7B_Multiview, +) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..820157e52d3dd731d8f024cdeb5437e03471a999 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_Video2World_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + ), + ), + job=dict(group="Video2World", name="Cosmos_Predict1_Video2World_7B"), + ) +) + + +Cosmos_Predict1_Video2World_14B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + latent_shape=[ + 16, + 16, + 88, + 160, + ], + conditioner=dict(video_cond_bool=dict()), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + ), + ), + job=dict(group="Video2World", name="Cosmos_Predict1_Video2World_14B"), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb", + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=121, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb", + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + tokenizer=dict( + video_vae=dict(pixel_chunk_duration=25, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb", + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # 
Latent temporal dim + 24, # Latent height dim + 24, # Latent width dim + ], + tokenizer=dict( + # video_vae=dict(pixel_chunk_duration=17, spatial_resolution="384"), + video_vae=dict(pixel_chunk_duration=25, spatial_resolution="384"), + ), + ), + ) +) + +Cosmos_Predict1_Video2World_14B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_14B", + ], + job=dict( + name="Cosmos_Predict1_Video2World_14B_Post_trained", + ), + ) +) + +Cosmos_Predict1_Video2World_7B_Post_trained_lora: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_Video2World_7B_Post_trained", + ], + job=dict( + name="Cosmos_Predict1_Video2World_7B_Post_trained_lora", + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=27, rank=8, scale=1), + ), + ) +) + +cs = ConfigStore.instance() +for _item in [ + Cosmos_Predict1_Video2World_7B, + Cosmos_Predict1_Video2World_14B, + Cosmos_Predict1_Video2World_7B_Post_trained, + Cosmos_Predict1_Video2World_14B_Post_trained, + Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb, + Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb, + Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb, + Cosmos_Predict1_Video2World_7B_Post_trained_lora, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-world-interpolator.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-world-interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..3580f8ef8947241412551242d21da32e012a6cc2 --- /dev/null +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-world-interpolator.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
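# A generic sketch of what the rank-8, scale-1 LoRA setting above
# (get_fa_ca_qv_lora_config(first_nblocks=27, rank=8, scale=1)) means for a
# single attention q/v projection: the frozen base weight is augmented with a
# trainable low-rank update, y = W x + scale * B(A x). This is the standard
# LoRA formulation, not the repository's peft_control implementation; the
# class below is illustrative only.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)            # freeze the pretrained projection
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)     # adapter starts as a no-op
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

q_proj = LoRALinear(nn.Linear(4096, 4096), rank=8, scale=1.0)
out = q_proj(torch.randn(2, 16, 4096))         # only the rank-8 factors are trainable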
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT +from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +Cosmos_Predict1_WorldInterpolator_7B: LazyDict = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, + "_self_", + ], + model=dict( + sde=L(EDMSDE)( + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + ), + input_image_key="images_1024", + latent_shape=[ + 16, + 4, + 88, + 160, + ], + tokenizer=dict( + video_vae=dict( + pixel_chunk_duration=9, + ) + ), + vae=dict( # Added VAE field + pixel_chunk_duration=9, + latent_ch=16, + ), + adjust_video_noise=True, + num_latents_to_drop=1, + context_parallel_size=1, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_and_last_1", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + apply_corruption_to_condition_region_sigma_value=[0.001], + ), + text=dict( + dropout_rate=0.5, + ), + ), + net=L(VideoExtendGeneralDIT)( + extra_per_block_abs_pos_emb=True, + rope_h_extrapolation_ratio=1.0, + rope_w_extrapolation_ratio=1.0, + rope_t_extrapolation_ratio=2.0, + extra_per_block_abs_pos_emb_type="learnable", + ), + ), + job=dict(group="WorldInterpolator", name="Cosmos_Predict1_WorldInterpolator_7B"), + ) +) + +Cosmos_Predict1_WorldInterpolator_7B_Post_trained: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/Cosmos_Predict1_WorldInterpolator_7B", + ], + job=dict( + name="Cosmos_Predict1_WorldInterpolator_7B_Post_trained", + ), + ) +) + + +cs = ConfigStore.instance() +for _item in [ + Cosmos_Predict1_WorldInterpolator_7B, + Cosmos_Predict1_WorldInterpolator_7B_Post_trained, +]: + cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) diff --git a/cosmos_predict1/diffusion/config/registry.py b/cosmos_predict1/diffusion/config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8a39525d0b03116da3ac08c55fa729f85cdc7c64 --- /dev/null +++ b/cosmos_predict1/diffusion/config/registry.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
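# A small sketch of how p_mean / p_std are conventionally used in EDM-style
# training (Karras et al., 2022), which the EDMSDE and augment-sigma settings
# above parameterize: noise levels are drawn log-normally,
# sigma = exp(p_mean + p_std * n) with n ~ N(0, 1), then clamped to
# [sigma_min, sigma_max]. This assumes the usual EDM convention and is
# illustrative, not the repository's sampler.
import torch

def sample_sigma(batch: int, p_mean: float, p_std: float,
                 sigma_min: float = 0.0002, sigma_max: float = 80.0) -> torch.Tensor:
    sigma = torch.exp(p_mean + p_std * torch.randn(batch))
    return sigma.clamp(min=sigma_min, max=sigma_max)

# With augment_sigma_sample_p_mean=-3.0 and p_std=2.0, the conditioning frames
# are corrupted with mostly small, occasionally large noise levels.
sigmas = sample_sigma(4, p_mean=-3.0, p_std=2.0)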
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.config.base.conditioner import ( + BaseVideoConditionerConfig, + VideoConditionerFpsSizePaddingConfig, + VideoConditionerFpsSizePaddingFrameRepeatConfig, + VideoExtendConditionerConfig, + VideoExtendConditionerFrameRepeatConfig, +) +from cosmos_predict1.diffusion.config.base.net import FADITV2_14B_Config, FADITV2_Multiview_Config, FADITV2Config +from cosmos_predict1.diffusion.config.base.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 + + +def register_net(cs): + cs.store( + group="net", + package="model.net", + name="faditv2_7b", + node=FADITV2Config, + ) + cs.store( + group="net", + package="model.net", + name="faditv2_14b", + node=FADITV2_14B_Config, + ) + cs.store( + group="net", + package="model.net", + name="faditv2_multiview_7b", + node=FADITV2_Multiview_Config, + ) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="basic", + node=BaseVideoConditionerConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="add_fps_image_size_padding_mask", + node=VideoConditionerFpsSizePaddingConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="video_cond", + node=VideoExtendConditionerConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="add_fps_image_size_padding_mask_frame_repeat", + node=VideoConditionerFpsSizePaddingFrameRepeatConfig, + ) + cs.store( + group="conditioner", + package="model.conditioner", + name="video_cond_frame_repeat", + node=VideoExtendConditionerFrameRepeatConfig, + ) + + +def register_tokenizer(cs): + cs.store( + group="tokenizer", + package="model.tokenizer", + name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", + node=get_cosmos_diffusion_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), + ) + + +def register_configs(): + cs = ConfigStore.instance() + + register_net(cs) + register_conditioner(cs) + register_tokenizer(cs) diff --git a/cosmos_predict1/diffusion/functional/batch_ops.py b/cosmos_predict1/diffusion/functional/batch_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a72b24097f7cc9e7e6a8b324919455131bf84d47 --- /dev/null +++ b/cosmos_predict1/diffusion/functional/batch_ops.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Functions for performing operations with broadcasting to the right axis +# +# Example +# input1: tensor of size (N1, N2) +# input2: tensor of size (N1, N2, N3, N4) +# batch_mul(input1, input2) = input1[:, :, None, None] * input2 +# +# If the common dimensions don't match, we raise an assertion error. 
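# A concrete version of the broadcasting example in the comment above: batch_mul
# aligns the shared leading dimensions and appends singleton axes on the right
# before multiplying, which here is exactly input1[:, :, None, None] * input2.
import torch

input1 = torch.randn(2, 3)           # (N1, N2)
input2 = torch.randn(2, 3, 4, 5)     # (N1, N2, N3, N4)

expected = input1[:, :, None, None] * input2
manual = input1.reshape(2, 3, 1, 1) * input2   # what common_broadcast + batch_mul compute
assert torch.equal(expected, manual) and manual.shape == (2, 3, 4, 5)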
+ +from torch import Tensor + + +def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: + ndims1 = x.ndim + ndims2 = y.ndim + + common_ndims = min(ndims1, ndims2) + for axis in range(common_ndims): + assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_add(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x + y + + +def batch_mul(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x * y + + +def batch_sub(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x - y + + +def batch_div(x: Tensor, y: Tensor) -> Tensor: + x, y = common_broadcast(x, y) + return x / y diff --git a/cosmos_predict1/diffusion/functional/multi_step.py b/cosmos_predict1/diffusion/functional/multi_step.py new file mode 100644 index 0000000000000000000000000000000000000000..76cb57aea441bfa36ddf7eeac9be75b40761cc5b --- /dev/null +++ b/cosmos_predict1/diffusion/functional/multi_step.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Impl of multistep methods to solve the ODE in the diffusion model. +""" + +from typing import Callable, List, Tuple + +import torch + +from cosmos_predict1.diffusion.functional.runge_kutta import reg_x0_euler_step, res_x0_rk2_step + + +def order2_fn( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + impl the second order multistep method in https://arxiv.org/pdf/2308.02157 + Adams Bashforth approach! + """ + if x0_preds: + x0_s1, s1 = x0_preds[0] + x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) + else: + x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] + return x_t, [(x0_s, s)] + + +# key: method name, value: method function +# key: order + algorithm name +MULTISTEP_FNs = { + "2ab": order2_fn, +} + + +def get_multi_step_fn(name: str) -> Callable: + if name in MULTISTEP_FNs: + return MULTISTEP_FNs[name] + methods = "\n\t".join(MULTISTEP_FNs.keys()) + raise RuntimeError("Only support multistep method\n" + methods) + + +def is_multi_step_fn_supported(name: str) -> bool: + """ + Check if the multistep method is supported. + """ + return name in MULTISTEP_FNs diff --git a/cosmos_predict1/diffusion/functional/runge_kutta.py b/cosmos_predict1/diffusion/functional/runge_kutta.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5841db9fb4ad463f9411206c953f405a7fad50 --- /dev/null +++ b/cosmos_predict1/diffusion/functional/runge_kutta.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Tuple + +import torch + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul + + +def phi1(t: torch.Tensor) -> torch.Tensor: + """ + Compute the first order phi function: (exp(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi1 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float64) + return (torch.expm1(t) / t).to(dtype=input_dtype) + + +def phi2(t: torch.Tensor) -> torch.Tensor: + """ + Compute the second order phi function: (phi1(t) - 1) / t. + + Args: + t: Input tensor. + + Returns: + Tensor: Result of phi2 function. + """ + input_dtype = t.dtype + t = t.to(dtype=torch.float64) + return ((phi1(t) - 1.0) / t).to(dtype=input_dtype) + + +def res_x0_rk2_step( + x_s: torch.Tensor, + t: torch.Tensor, + s: torch.Tensor, + x0_s: torch.Tensor, + s1: torch.Tensor, + x0_s1: torch.Tensor, +) -> torch.Tensor: + """ + Perform a residual-based 2nd order Runge-Kutta step. + + Args: + x_s: Current state tensor. + t: Target time tensor. + s: Current time tensor. + x0_s: Prediction at current time. + s1: Intermediate time tensor. + x0_s1: Prediction at intermediate time. + + Returns: + Tensor: Updated state tensor. + + Raises: + AssertionError: If step size is too small. + """ + s = -torch.log(s) + t = -torch.log(t) + m = -torch.log(s1) + + dt = t - s + assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" + + c2 = (m - s) / dt + phi1_val, phi2_val = phi1(-dt), phi2(-dt) + + # Handle edge case where t = s = m + b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0) + b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0) + + return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1)) + + +def reg_x0_euler_step( + x_s: torch.Tensor, + s: torch.Tensor, + t: torch.Tensor, + x0_s: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a regularized Euler step based on x0 prediction. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_s: Prediction at current time. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current prediction. + """ + coef_x0 = (s - t) / s + coef_xs = t / s + return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s + + +def reg_eps_euler_step( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, eps_s: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a regularized Euler step based on epsilon prediction. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + eps_s: Epsilon prediction at current time. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current x0 prediction. 
+ """ + return x_s + batch_mul(eps_s, t - s), x_s + batch_mul(eps_s, 0 - s) + + +def rk1_euler( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a first-order Runge-Kutta (Euler) step. + + Recommended for diffusion models with guidance or model undertrained + Usually more stable at the cost of a bit slower convergence. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + x0_s = x0_fn(x_s, s) + return reg_x0_euler_step(x_s, s, t, x0_s) + + +def rk2_mid_stable( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a stable second-order Runge-Kutta (midpoint) step. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + s1 = torch.sqrt(s * t) + x_s1, _ = rk1_euler(x_s, s, s1, x0_fn) + + x0_s1 = x0_fn(x_s1, s1) + return reg_x0_euler_step(x_s, s, t, x0_s1) + + +def rk2_mid(x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a second-order Runge-Kutta (midpoint) step. + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. + """ + s1 = torch.sqrt(s * t) + x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) + + x0_s1 = x0_fn(x_s1, s1) + + return res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1), x0_s1 + + +def rk_2heun_naive( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive second-order Runge-Kutta (Heun's method) step. + Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis + Recommended for diffusion models without guidance and relative large NFE + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) + eps_s = batch_mul(1.0 / s, x_t - x0_s) + x0_t = x0_fn(x_t, t) + eps_t = batch_mul(1.0 / t, x_t - x0_t) + + avg_eps = (eps_s + eps_t) / 2 + + return reg_eps_euler_step(x_s, s, t, avg_eps) + + +def rk_2heun_edm( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive second-order Runge-Kutta (Heun's method) step. + Impl based no EDM second order Heun method + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) + x0_t = x0_fn(x_t, t) + + avg_x0 = (x0_s + x0_t) / 2 + + return reg_x0_euler_step(x_s, s, t, avg_x0) + + +def rk_3kutta_naive( + x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform a naive third-order Runge-Kutta step. 
+ Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis + Recommended for diffusion models without guidance and relative large NFE + + Args: + x_s: Current state tensor. + s: Current time tensor. + t: Target time tensor. + x0_fn: Function to compute x0 prediction. + + Returns: + Tuple[Tensor, Tensor]: Updated state tensor and current state. + """ + c2, c3 = 0.5, 1.0 + a31, a32 = -1.0, 2.0 + b1, b2, b3 = 1.0 / 6, 4.0 / 6, 1.0 / 6 + + delta = t - s + + s1 = c2 * delta + s + s2 = c3 * delta + s + x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) + eps_s = batch_mul(1.0 / s, x_s - x0_s) + x0_s1 = x0_fn(x_s1, s1) + eps_s1 = batch_mul(1.0 / s1, x_s1 - x0_s1) + + _eps = a31 * eps_s + a32 * eps_s1 + x_s2, _ = reg_eps_euler_step(x_s, s, s2, _eps) + + x0_s2 = x0_fn(x_s2, s2) + eps_s2 = batch_mul(1.0 / s2, x_s2 - x0_s2) + + avg_eps = b1 * eps_s + b2 * eps_s1 + b3 * eps_s2 + return reg_eps_euler_step(x_s, s, t, avg_eps) + + +# key : order + name +RK_FNs = { + "1euler": rk1_euler, + "2mid": rk2_mid, + "2mid_stable": rk2_mid_stable, + "2heun_edm": rk_2heun_edm, + "2heun_naive": rk_2heun_naive, + "3kutta_naive": rk_3kutta_naive, +} + + +def get_runge_kutta_fn(name: str) -> Callable: + """ + Get the specified Runge-Kutta function. + + Args: + name: Name of the Runge-Kutta method. + + Returns: + Callable: The specified Runge-Kutta function. + + Raises: + RuntimeError: If the specified method is not supported. + """ + if name in RK_FNs: + return RK_FNs[name] + methods = "\n\t".join(RK_FNs.keys()) + raise RuntimeError(f"Only support the following Runge-Kutta methods:\n\t{methods}") + + +def is_runge_kutta_fn_supported(name: str) -> bool: + """ + Check if the specified Runge-Kutta function is supported. + + Args: + name: Name of the Runge-Kutta method. + + Returns: + bool: True if the method is supported, False otherwise. + """ + return name in RK_FNs diff --git a/cosmos_predict1/diffusion/inference/cache_3d.py b/cosmos_predict1/diffusion/inference/cache_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d3697779cf22acebd3e4100f1254fdc7f33e33 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/cache_3d.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from einops import rearrange + +from cosmos_predict1.diffusion.inference.forward_warp_utils_pytorch import ( + forward_warp, + reliable_depth_mask_range_batch, + unproject_points, +) +from cosmos_predict1.diffusion.inference.camera_utils import align_depth + +class Cache3D_Base: + def __init__( + self, + input_image, + input_depth, + input_w2c, + input_intrinsics, + input_mask=None, + input_format=None, + input_points=None, + weight_dtype=torch.float32, + is_depth=True, + device="cuda", + filter_points_threshold=1.0, + foreground_masking=False, + ): + """ + input_image: Tensor with varying dimensions. 
+ input_format: List of dimension labels corresponding to input_image's dimensions. + E.g., ['B', 'C', 'H', 'W'], ['B', 'F', 'C', 'H', 'W'], etc. + """ + self.weight_dtype = weight_dtype + self.is_depth = is_depth + self.device = device + self.filter_points_threshold = filter_points_threshold + self.foreground_masking = foreground_masking + if input_format is None: + assert input_image.dim() == 4 + input_format = ["B", "C", "H", "W"] + + # Map dimension names to their indices in input_image + format_to_indices = {dim: idx for idx, dim in enumerate(input_format)} + input_shape = input_image.shape + if input_mask is not None: + input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C")) + + # B (batch size), F (frame count), N dimensions: no aggregation during warping. + # Only broadcasting over F to match the target w2c. + # V: aggregate via concatenation or duster + B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1 # batch + F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1 # frame + N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1 # buffer + V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1 # view + H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None + W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None + + # Desired dimension order + desired_dims = ["B", "F", "N", "V", "C", "H", "W"] + + # Build permute order based on input_format + permute_order = [] + for dim in desired_dims: + idx = format_to_indices.get(dim) + if idx is not None: + permute_order.append(idx) + else: + # Placeholder for dimensions to be added later + permute_order.append(None) + + # Remove None values for permute operation + permute_indices = [idx for idx in permute_order if idx is not None] + input_image = input_image.permute(*permute_indices) + + # Insert dimensions of size 1 where necessary + for i, idx in enumerate(permute_order): + if idx is None: + input_image = input_image.unsqueeze(i) + + # Now input_image has the shape B x F x N x V x C x H x W + if input_mask is not None: + self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:] + self.input_mask = self.input_mask.to("cpu") + else: + self.input_mask = None + self.input_image = input_image + self.input_image = self.input_image.to(weight_dtype).to("cpu") + + if input_points is not None: + self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu") + self.input_depth = None + else: + input_depth = torch.nan_to_num(input_depth, nan=100) + input_depth = torch.clamp(input_depth, min=0, max=100) + if weight_dtype == torch.float16: + input_depth = torch.clamp(input_depth, max=70) + self.input_points = ( + self._compute_input_points( + input_depth.reshape(-1, 1, H, W), + input_w2c.reshape(-1, 4, 4), + input_intrinsics.reshape(-1, 3, 3), + ) + .to(weight_dtype) + .reshape(B, F, N, V, H, W, 3) + .to("cpu") + ) + self.input_depth = input_depth + + if self.filter_points_threshold < 1.0 and input_depth is not None: + input_depth = input_depth.reshape(-1, 1, H, W) + depth_mask = reliable_depth_mask_range_batch(input_depth, ratio_thresh=self.filter_points_threshold).reshape(B, F, N, V, 1, H, W) + if self.input_mask is None: + self.input_mask = depth_mask.to("cpu") + else: + self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device) + self.boundary_mask = None + if foreground_masking: + input_depth 
= input_depth.reshape(-1, 1, H, W) + depth_mask = reliable_depth_mask_range_batch(input_depth) + self.boundary_mask = (~depth_mask).reshape(B, F, N, V, 1, H, W).to("cpu") + + def _compute_input_points(self, input_depth, input_w2c, input_intrinsics): + input_points = unproject_points( + input_depth, + input_w2c, + input_intrinsics, + is_depth=self.is_depth, + ) + return input_points + + def update_cache(self): + raise NotImplementedError + + def input_frame_count(self) -> int: + return self.input_image.shape[1] + + def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0): + bs, F_target, _, _ = target_w2cs.shape + + B, F, N, V, C, H, W = self.input_image.shape + assert bs == B + + target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4) + target_intrinsics = ( + target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3, 3).reshape(-1, 3, 3) + ) + + first_images = rearrange(self.input_image[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, C, H, W), "B F N V C H W-> (B F N) V C H W").to(self.device) + first_points = rearrange( + self.input_points[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, H, W, 3), "B F N V H W C-> (B F N) V H W C" + ).to(self.device) + first_masks = rearrange( + self.input_mask[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W" + ).to(self.device) if self.input_mask is not None else None + boundary_masks = rearrange( + self.boundary_mask.expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W" + ) if self.boundary_mask is not None else None + + if first_images.shape[1] == 1: + warp_chunk_size = 2 + rendered_warp_images = [] + rendered_warp_masks = [] + rendered_warp_depth = [] + rendered_warped_flows = [] + + first_images = first_images.squeeze(1) + first_points = first_points.squeeze(1) + first_masks = first_masks.squeeze(1) if first_masks is not None else None + for i in range(0, first_images.shape[0], warp_chunk_size): + ( + rendered_warp_images_chunk, + rendered_warp_masks_chunk, + rendered_warp_depth_chunk, + rendered_warped_flows_chunk, + ) = forward_warp( + first_images[i : i + warp_chunk_size], + mask1=first_masks[i : i + warp_chunk_size] if first_masks is not None else None, + depth1=None, + transformation1=None, + transformation2=target_w2cs[i : i + warp_chunk_size], + intrinsic1=target_intrinsics[i : i + warp_chunk_size], + intrinsic2=target_intrinsics[i : i + warp_chunk_size], + render_depth=render_depth, + world_points1=first_points[i : i + warp_chunk_size], + foreground_masking=self.foreground_masking, + boundary_mask=boundary_masks[i : i + warp_chunk_size, 0, 0] if boundary_masks is not None else None + ) + rendered_warp_images.append(rendered_warp_images_chunk) + rendered_warp_masks.append(rendered_warp_masks_chunk) + rendered_warp_depth.append(rendered_warp_depth_chunk) + rendered_warped_flows.append(rendered_warped_flows_chunk) + rendered_warp_images = torch.cat(rendered_warp_images, dim=0) + rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0) + if render_depth: + rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0) + rendered_warped_flows = torch.cat(rendered_warped_flows, dim=0) + + else: + raise NotImplementedError + + pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N) + masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N) + if 
render_depth: + pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N) + return pixels, masks + + +class Cache3D_Buffer(Cache3D_Base): + def __init__(self, frame_buffer_max=0, noise_aug_strength=0, generator=None, **kwargs): + super().__init__(**kwargs) + self.frame_buffer_max = frame_buffer_max + self.noise_aug_strength = noise_aug_strength + self.generator = generator + + def update_cache(self, new_image, new_depth, new_w2c, new_mask=None, new_intrinsics=None, depth_alignment=True, alignment_method="non_rigid"): # 3D cache + new_image = new_image.to(self.weight_dtype).to(self.device) + new_depth = new_depth.to(self.weight_dtype).to(self.device) + new_w2c = new_w2c.to(self.weight_dtype).to(self.device) + if new_intrinsics is not None: + new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device) + + new_depth = torch.nan_to_num(new_depth, nan=1e4) + new_depth = torch.clamp(new_depth, min=0, max=1e4) + + if depth_alignment: + target_depth, target_mask = self.render_cache( + new_w2c.unsqueeze(1), new_intrinsics.unsqueeze(1), render_depth=True + ) + target_depth, target_mask = target_depth[:, :, 0], target_mask[:, :, 0] + if alignment_method == "rigid": + new_depth = ( + align_depth( + new_depth.squeeze(), + target_depth.squeeze(), + target_mask.bool().squeeze(), + ) + .reshape_as(new_depth) + .detach() + ) + elif alignment_method == "non_rigid": + with torch.enable_grad(): + new_depth = ( + align_depth( + new_depth.squeeze(), + target_depth.squeeze(), + target_mask.bool().squeeze(), + k=new_intrinsics.squeeze(), + c2w=torch.inverse(new_w2c.squeeze()), + alignment_method="non_rigid", + num_iters=100, + lambda_arap=0.1, + smoothing_kernel_size=3, + ) + .reshape_as(new_depth) + .detach() + ) + else: + raise NotImplementedError + new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu() + new_image = new_image.cpu() + + if self.filter_points_threshold < 1.0: + B, F, N, V, C, H, W = self.input_image.shape + new_depth = new_depth.reshape(-1, 1, H, W) + depth_mask = reliable_depth_mask_range_batch(new_depth, ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W) + if new_mask is None: + new_mask = depth_mask.to("cpu") + else: + new_mask = new_mask * depth_mask.to(new_mask.device) + if new_mask is not None: + new_mask = new_mask.cpu() + if self.frame_buffer_max > 1: # newest frame first + if self.input_image.shape[2] < self.frame_buffer_max: + self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2) + self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2) + if self.input_mask is not None: + self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2) + else: + self.input_image[:, :, 0] = new_image[:, None, None] + self.input_points[:, :, 0] = new_points[:, None, None] + if self.input_mask is not None: + self.input_mask[:, :, 0] = new_mask[:, None, None] + else: + self.input_image = new_image[:, None, None, None] + self.input_points = new_points[:, None, None, None] + + + def render_cache( + self, + target_w2cs, + target_intrinsics, + render_depth: bool = False, + start_frame_idx: int = 0, # For consistency with Cache4D + ): + assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3D_Buffer" + + output_device = target_w2cs.device + target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device) + target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device) + pixels, masks = super().render_cache( + 
target_w2cs, target_intrinsics, render_depth + ) + if not render_depth: + noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype) + per_buffer_noise = ( + torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device) + * self.noise_aug_strength + ) + pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1) # B, F, N, C, H, W + return pixels.to(output_device), masks.to(output_device) + + +class Cache4D(Cache3D_Base): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def update_cache(self, **kwargs): + raise NotImplementedError + + def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0): + rendered_warp_images, rendered_warp_masks = super().render_cache(target_w2cs, target_intrinsics, render_depth, start_frame_idx) + return rendered_warp_images, rendered_warp_masks diff --git a/cosmos_predict1/diffusion/inference/camera_utils.py b/cosmos_predict1/diffusion/inference/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aa370636e2067e95d1332a7001b73e3583d35448 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/camera_utils.py @@ -0,0 +1,347 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
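+
+"""Camera trajectory generation and depth-alignment utilities.
+
+`generate_camera_trajectory` builds per-frame world-to-camera matrices for
+simple pans, zooms and spiral motions around an initial pose, while
+`align_depth` fits a newly predicted depth map to depths rendered from the
+existing 3D cache, either with a rigid affine fit in inverse-depth space or
+with a non-rigid per-pixel scale optimization regularized by an ARAP term.
+"""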
+ +import torch +import math +import torch.nn.functional as F +from .forward_warp_utils_pytorch import unproject_points + +def apply_transformation(Bx4x4, another_matrix): + B = Bx4x4.shape[0] + if another_matrix.dim() == 2: + another_matrix = another_matrix.unsqueeze(0).expand(B, -1, -1) # Make another_matrix compatible with batch size + transformed_matrix = torch.bmm(Bx4x4, another_matrix) # Shape: (B, 4, 4) + + return transformed_matrix + + +def look_at_matrix(camera_pos, target, invert_pos=True): + """Creates a 4x4 look-at matrix, keeping the camera pointing towards a target.""" + forward = (target - camera_pos).float() + forward = forward / torch.norm(forward) + + up = torch.tensor([0.0, 1.0, 0.0], device=camera_pos.device) # assuming Y-up coordinate system + right = torch.cross(up, forward) + right = right / torch.norm(right) + up = torch.cross(forward, right) + + look_at = torch.eye(4, device=camera_pos.device) + look_at[0, :3] = right + look_at[1, :3] = up + look_at[2, :3] = forward + look_at[:3, 3] = (-camera_pos) if invert_pos else camera_pos + + return look_at + +def create_horizontal_trajectory( + world_to_camera_matrix, center_depth, positive=True, n_steps=13, distance=0.1, device="cuda", axis="x", camera_rotation="center_facing" +): + look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) + # Spiral motion key points + trajectory = [] + translation_positions = [] + initial_camera_pos = torch.tensor([0, 0, 0], device=device) + + for i in range(n_steps): + if axis == "x": # pos - right + x = i * distance * center_depth / n_steps * (1 if positive else -1) + y = 0 + z = 0 + elif axis == "y": # pos - down + x = 0 + y = i * distance * center_depth / n_steps * (1 if positive else -1) + z = 0 + elif axis == "z": # pos - in + x = 0 + y = 0 + z = i * distance * center_depth / n_steps * (1 if positive else -1) + else: + raise ValueError("Axis should be x, y or z") + + translation_positions.append(torch.tensor([x, y, z], device=device)) + + for pos in translation_positions: + camera_pos = initial_camera_pos + pos + if camera_rotation == "trajectory_aligned": + _look_at = look_at + pos * 2 + elif camera_rotation == "center_facing": + _look_at = look_at + elif camera_rotation == "no_rotation": + _look_at = look_at + pos + else: + raise ValueError("Camera rotation should be center_facing or trajectory_aligned") + view_matrix = look_at_matrix(camera_pos, _look_at) + trajectory.append(view_matrix) + trajectory = torch.stack(trajectory) + return apply_transformation(trajectory, world_to_camera_matrix) + + +def create_spiral_trajectory( + world_to_camera_matrix, + center_depth, + radius_x=0.03, + radius_y=0.02, + radius_z=0.0, + positive=True, + camera_rotation="center_facing", + n_steps=13, + device="cuda", + start_from_zero=True, + num_circles=1, +): + + look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) + + # Spiral motion key points + trajectory = [] + spiral_positions = [] + initial_camera_pos = torch.tensor([0, 0, 0], device=device) # world_to_camera_matrix[:3, 3].clone() + + example_scale = 1.0 + + theta_max = 2 * math.pi * num_circles + + for i in range(n_steps): + # theta = 2 * math.pi * i / (n_steps-1) # angle for each point + theta = theta_max * i / (n_steps - 1) # angle for each point + if start_from_zero: + x = radius_x * (math.cos(theta) - 1) * (1 if positive else -1) * (center_depth / example_scale) + else: + x = radius_x * (math.cos(theta)) * (center_depth / example_scale) + + y = radius_y * math.sin(theta) * (center_depth / example_scale) + z = radius_z * 
math.sin(theta) * (center_depth / example_scale) + spiral_positions.append(torch.tensor([x, y, z], device=device)) + + for pos in spiral_positions: + if camera_rotation == "center_facing": + view_matrix = look_at_matrix(initial_camera_pos + pos, look_at) + elif camera_rotation == "trajectory_aligned": + view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos * 2) + elif camera_rotation == "no_rotation": + view_matrix = look_at_matrix(initial_camera_pos + pos, look_at + pos) + else: + raise ValueError("Camera rotation should be center_facing, trajectory_aligned or no_rotation") + trajectory.append(view_matrix) + trajectory = torch.stack(trajectory) + return apply_transformation(trajectory, world_to_camera_matrix) + + +def generate_camera_trajectory( + trajectory_type: str, + initial_w2c: torch.Tensor, # Shape: (4, 4) + initial_intrinsics: torch.Tensor, # Shape: (3, 3) + num_frames: int, + movement_distance: float, + camera_rotation: str, + center_depth: float = 1.0, + device: str = "cuda", +): + """ + Generates a sequence of camera poses (world-to-camera matrices) and intrinsics + for a specified trajectory type. + + Args: + trajectory_type: Type of trajectory (e.g., "left", "right", "up", "down", "zoom_in", "zoom_out"). + initial_w2c: Initial world-to-camera matrix (4x4 tensor or num_framesx4x4 tensor). + initial_intrinsics: Camera intrinsics matrix (3x3 tensor or num_framesx3x3 tensor). + num_frames: Number of frames (steps) in the trajectory. + movement_distance: Distance factor for the camera movement. + camera_rotation: Type of camera rotation ('center_facing', 'no_rotation', 'trajectory_aligned'). + center_depth: Depth of the center point the camera might focus on. + device: Computation device ("cuda" or "cpu"). + + Returns: + A tuple (generated_w2cs, generated_intrinsics): + - generated_w2cs: Batch of world-to-camera matrices for the trajectory (1, num_frames, 4, 4 tensor). + - generated_intrinsics: Batch of camera intrinsics for the trajectory (1, num_frames, 3, 3 tensor). 
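+
+    Example (illustrative sketch; identity matrices are placeholder inputs and
+    a CUDA device is assumed)::
+
+        >>> w2cs, intrinsics = generate_camera_trajectory(
+        ...     trajectory_type="zoom_in",
+        ...     initial_w2c=torch.eye(4, device="cuda"),
+        ...     initial_intrinsics=torch.eye(3, device="cuda"),
+        ...     num_frames=13,
+        ...     movement_distance=0.1,
+        ...     camera_rotation="center_facing",
+        ... )
+        >>> w2cs.shape, intrinsics.shape
+        (torch.Size([1, 13, 4, 4]), torch.Size([1, 13, 3, 3]))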
+ """ + if trajectory_type in ["clockwise", "counterclockwise"]: + new_w2cs_seq = create_spiral_trajectory( + world_to_camera_matrix=initial_w2c, + center_depth=center_depth, + n_steps=num_frames, + positive=trajectory_type == "clockwise", + device=device, + camera_rotation=camera_rotation, + radius_x=movement_distance, + radius_y=movement_distance, + ) + else: + if trajectory_type == "left": + positive = False + axis = "x" + elif trajectory_type == "right": + positive = True + axis = "x" + elif trajectory_type == "up": + positive = False # Assuming 'up' means camera moves in negative y direction if y points down + axis = "y" + elif trajectory_type == "down": + positive = True # Assuming 'down' means camera moves in positive y direction if y points down + axis = "y" + elif trajectory_type == "zoom_in": + positive = True # Assuming 'zoom_in' means camera moves in positive z direction (forward) + axis = "z" + elif trajectory_type == "zoom_out": + positive = False # Assuming 'zoom_out' means camera moves in negative z direction (backward) + axis = "z" + else: + raise ValueError(f"Unsupported trajectory type: {trajectory_type}") + + # Generate world-to-camera matrices using create_horizontal_trajectory + new_w2cs_seq = create_horizontal_trajectory( + world_to_camera_matrix=initial_w2c, + center_depth=center_depth, + n_steps=num_frames, + positive=positive, + axis=axis, + distance=movement_distance, + device=device, + camera_rotation=camera_rotation, + ) + + generated_w2cs = new_w2cs_seq.unsqueeze(0) # Shape: [1, num_frames, 4, 4] + if initial_intrinsics.dim() == 2: + generated_intrinsics = initial_intrinsics.unsqueeze(0).unsqueeze(0).repeat(1, num_frames, 1, 1) + else: + generated_intrinsics = initial_intrinsics.unsqueeze(0) + + return generated_w2cs, generated_intrinsics + + +def _align_inv_depth_to_depth( + source_inv_depth: torch.Tensor, + target_depth: torch.Tensor, + target_mask: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Apply affine transformation to align source inverse depth to target depth. + + Args: + source_inv_depth: Inverse depth map to be aligned. Shape: (H, W). + target_depth: Target depth map. Shape: (H, W). + target_mask: Mask of valid target pixels. Shape: (H, W). + + Returns: + Aligned Depth map. Shape: (H, W). 
+ """ + target_inv_depth = 1.0 / target_depth + source_mask = source_inv_depth > 0 + target_depth_mask = target_depth > 0 + + if target_mask is None: + target_mask = target_depth_mask + else: + target_mask = torch.logical_and(target_mask > 0, target_depth_mask) + + # Remove outliers + outlier_quantiles = torch.tensor([0.1, 0.9], device=source_inv_depth.device) + + source_data_low, source_data_high = torch.quantile(source_inv_depth[source_mask], outlier_quantiles) + target_data_low, target_data_high = torch.quantile(target_inv_depth[target_mask], outlier_quantiles) + source_mask = (source_inv_depth > source_data_low) & (source_inv_depth < source_data_high) + target_mask = (target_inv_depth > target_data_low) & (target_inv_depth < target_data_high) + + mask = torch.logical_and(source_mask, target_mask) + + source_data = source_inv_depth[mask].view(-1, 1) + target_data = target_inv_depth[mask].view(-1, 1) + + ones = torch.ones((source_data.shape[0], 1), device=source_data.device) + source_data_h = torch.cat([source_data, ones], dim=1) + transform_matrix = torch.linalg.lstsq(source_data_h, target_data).solution + + scale, bias = transform_matrix[0, 0], transform_matrix[1, 0] + aligned_inv_depth = source_inv_depth * scale + bias + + return 1.0 / aligned_inv_depth + + +def align_depth( + source_depth: torch.Tensor, + target_depth: torch.Tensor, + target_mask: torch.Tensor, + k: torch.Tensor = None, + c2w: torch.Tensor = None, + alignment_method: str = "rigid", + num_iters: int = 100, + lambda_arap: float = 0.1, + smoothing_kernel_size: int = 3, +) -> torch.Tensor: + if alignment_method == "rigid": + source_inv_depth = 1.0 / source_depth + source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask) + return source_depth + elif alignment_method == "non_rigid": + if k is None or c2w is None: + raise ValueError("Camera intrinsics (k) and camera-to-world matrix (c2w) are required for non-rigid alignment") + + source_inv_depth = 1.0 / source_depth + source_depth = _align_inv_depth_to_depth(source_inv_depth, target_depth, target_mask) + + # Initialize scale map + sc_map = torch.ones_like(source_depth).float().to(source_depth.device).requires_grad_(True) + optimizer = torch.optim.Adam(params=[sc_map], lr=0.001) + + # Unproject target depth + target_unprojected = unproject_points( + target_depth.unsqueeze(0).unsqueeze(0), # Add batch and channel dimensions + c2w.unsqueeze(0), # Add batch dimension + k.unsqueeze(0), # Add batch dimension + is_depth=True, + mask=target_mask.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions + ).squeeze(0) # Remove batch dimension + + # Create smoothing kernel + smoothing_kernel = torch.ones( + (1, 1, smoothing_kernel_size, smoothing_kernel_size), + device=source_depth.device + ) / (smoothing_kernel_size**2) + + for _ in range(num_iters): + # Unproject scaled source depth + source_unprojected = unproject_points( + (source_depth * sc_map).unsqueeze(0).unsqueeze(0), # Add batch and channel dimensions + c2w.unsqueeze(0), # Add batch dimension + k.unsqueeze(0), # Add batch dimension + is_depth=True, + mask=target_mask.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions + ).squeeze(0) # Remove batch dimension + + # Data loss + data_loss = torch.abs(source_unprojected[target_mask] - target_unprojected[target_mask]).mean() + + # Apply smoothing filter to sc_map + sc_map_reshaped = sc_map.unsqueeze(0).unsqueeze(0) + sc_map_smoothed = F.conv2d( + sc_map_reshaped, + smoothing_kernel, + padding=smoothing_kernel_size // 2 + 
).squeeze(0).squeeze(0) + + # ARAP loss + arap_loss = torch.abs(sc_map_smoothed - sc_map).mean() + + # Total loss + loss = data_loss + lambda_arap * arap_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return source_depth * sc_map + else: + raise ValueError(f"Unsupported alignment method: {alignment_method}") diff --git a/cosmos_predict1/diffusion/inference/data_loader_utils.py b/cosmos_predict1/diffusion/inference/data_loader_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c47d0573fa548ce76819ff43ac9fbf5ebe9e98f1 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/data_loader_utils.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Data loading utilities for the distributed format: +- RGB from mp4 +- Depth from float16 numpy +- Camera data from float32 numpy +""" + +import os +import numpy as np +import torch +import cv2 +from pathlib import Path + + +def load_rgb_from_mp4(video_path): + """ + Load RGB video from mp4 file and convert to tensor. + + Args: + video_path: str, path to the mp4 file + + Returns: + torch.Tensor: RGB tensor of shape [T, C, H, W] with range [-1, 1] + """ + cap = cv2.VideoCapture(video_path) + + if not cap.isOpened(): + raise RuntimeError(f"Failed to open video file: {video_path}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + + # Convert BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame_rgb) + + cap.release() + + if not frames: + raise ValueError(f"No frames found in video: {video_path}") + + # Convert to numpy array and then tensor + frames_np = np.stack(frames, axis=0) # [T, H, W, C] + frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float() # [T, C, H, W] + + # Convert from [0, 255] to [-1, 1] + frames_tensor = (frames_tensor / 127.5) - 1.0 + + return frames_tensor + + +def load_depth_from_numpy(depth_path): + """ + Load depth data from compressed NPZ file. + + Args: + depth_path: str, path to the NPZ file + + Returns: + torch.Tensor: Depth tensor of shape [T, 1, H, W] + """ + data = np.load(depth_path) + depth_np = data['depth'] # [T, H, W] + depth_tensor = torch.from_numpy(depth_np.astype(np.float32)) + + # Add channel dimension: [T, H, W] -> [T, 1, H, W] + depth_tensor = depth_tensor.unsqueeze(1) + + return depth_tensor + + +def load_mask_from_numpy(mask_path): + """ + Load mask data from compressed NPZ file. 
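+
+    The archive is expected to hold a boolean array under the key 'mask' with
+    shape [T, H, W]; it is converted to float32 and a channel dimension is
+    added.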
+ + Args: + mask_path: str, path to the NPZ file + + Returns: + torch.Tensor: Mask tensor of shape [T, 1, H, W] + """ + data = np.load(mask_path) + mask_np = data['mask'] # [T, H, W] as bool + mask_tensor = torch.from_numpy(mask_np.astype(np.float32)) # Convert bool to float32 + + # Add channel dimension: [T, H, W] -> [T, 1, H, W] + mask_tensor = mask_tensor.unsqueeze(1) + + return mask_tensor + + +def load_camera_from_numpy(data_dir): + """ + Load camera parameters from compressed NPZ file. + + Args: + data_dir: str, directory containing camera.npz + + Returns: + tuple: (w2c_tensor, intrinsics_tensor) + - w2c_tensor: torch.Tensor of shape [T, 4, 4] + - intrinsics_tensor: torch.Tensor of shape [T, 3, 3] + """ + camera_path = os.path.join(data_dir, "camera.npz") + + if not os.path.exists(camera_path): + raise FileNotFoundError(f"camera file not found: {camera_path}") + + data = np.load(camera_path) + w2c_np = data['w2c'] + intrinsics_np = data['intrinsics'] + + w2c_tensor = torch.from_numpy(w2c_np) + intrinsics_tensor = torch.from_numpy(intrinsics_np) + + return w2c_tensor, intrinsics_tensor + + +def load_data_distributed_format(data_dir): + """Load data from distributed format (mp4 + numpy files)""" + data_path = Path(data_dir) + + # Load RGB from mp4 + cap = cv2.VideoCapture(str(data_path / "rgb.mp4")) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + + frames_np = np.stack(frames, axis=0) + image_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float() + image_tensor = (image_tensor / 127.5) - 1.0 # [0,255] -> [-1,1] + + # Load depth and mask + depth_tensor = torch.from_numpy(np.load(data_path / "depth.npz")['depth'].astype(np.float32)).unsqueeze(1) + mask_tensor = torch.from_numpy(np.load(data_path / "mask.npz")['mask'].astype(np.float32)).unsqueeze(1) + + # Load camera data + camera_data = np.load(data_path / "camera.npz") + w2c_tensor = torch.from_numpy(camera_data['w2c']) + intrinsics_tensor = torch.from_numpy(camera_data['intrinsics']) + + return image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor + + +def load_data_packaged_format(pt_path): + """ + Load data from the packaged pt format for backward compatibility. + + Args: + pt_path: str, path to the pt file + + Returns: + tuple: (image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor) + """ + data = torch.load(pt_path) + + if len(data) != 5: + raise ValueError(f"Expected 5 tensors in pt file, got {len(data)}") + + return data + + +def load_data_auto_detect(input_path): + """Auto-detect format and load data""" + input_path = Path(input_path) + + if input_path.is_file() and input_path.suffix == '.pt': + return load_data_packaged_format(input_path) + elif input_path.is_dir(): + return load_data_distributed_format(input_path) + else: + raise ValueError(f"Invalid input path: {input_path}") \ No newline at end of file diff --git a/cosmos_predict1/diffusion/inference/forward_warp_utils_pytorch.py b/cosmos_predict1/diffusion/inference/forward_warp_utils_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..3bba954a1175492d2496b1dda16526e09cc98170 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/forward_warp_utils_pytorch.py @@ -0,0 +1,721 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple +import numpy as np +import torch +import os +import torch.nn.functional as F +try: + import warp as wp +except ImportError: + raise ImportError("NVIDIA Warp is required for ray-triangle intersection") + +_warp_initialized = False +_ray_triangle_intersection_func = None + +def _init_warp(): + global _warp_initialized, _ray_triangle_intersection_func + + if not _warp_initialized: + print(f"Initializing Warp library (local_rank {os.getenv('LOCAL_RANK')})...") + wp.init() + _warp_initialized = True + print(f"Warp library initialized successfully (local_rank {os.getenv('LOCAL_RANK')})") + + if _ray_triangle_intersection_func is None: + try: + from .ray_triangle_intersection_warp import ray_triangle_intersection_warp + _ray_triangle_intersection_func = ray_triangle_intersection_warp + print(f"Warp: ray_triangle_intersection_warp kernel loaded (local_rank {os.getenv('LOCAL_RANK')})") + except ImportError: + from ray_triangle_intersection_warp import ray_triangle_intersection_warp + _ray_triangle_intersection_func = ray_triangle_intersection_warp + print(f"Warp: ray_triangle_intersection_warp kernel loaded (local_rank {os.getenv('LOCAL_RANK')})") + + +def points_to_mesh(points, mask, resolution=None): + """ + Convert a grid of 3D points to a triangle mesh based on mask. 
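+
+    Each 2x2 pixel patch with at least one masked-in corner is split into two
+    triangles; unused vertices are then dropped and face indices remapped.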
+ + Args: + points: Tensor of shape [H, W, 3] containing 3D points + mask: Tensor of shape [H, W] containing binary mask + resolution: Optional tuple (new_H, new_W) to resize to + + Returns: + vertices: Tensor of shape [N, 3] containing unique vertices + faces: Tensor of shape [M, 3] containing triangle indices + """ + H, W = points.shape[:2] + + # Resize if resolution is provided + if resolution is not None: + new_H, new_W = resolution + # Resize points using bilinear interpolation + points = points.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W] + points = F.interpolate(points, size=(new_H, new_W), mode='bilinear', align_corners=False) + points = points.squeeze(0).permute(1, 2, 0) # [new_H, new_W, 3] + + # Resize mask using nearest neighbor + mask = mask.unsqueeze(0).unsqueeze(0).float() # [1, 1, H, W] + mask = F.interpolate(mask, size=(new_H, new_W), mode='nearest') + mask = mask.squeeze(0).squeeze(0).bool() # [new_H, new_W] + + H, W = new_H, new_W + + # Create vertex indices grid + vertex_indices = torch.arange(H * W, device=points.device).reshape(H, W) + + # Find 2x2 patches where at least one vertex is in the mask + # Create shifted views for efficient neighbor checking + mask_tl = mask[:-1, :-1] # top-left + mask_tr = mask[:-1, 1:] # top-right + mask_bl = mask[1:, :-1] # bottom-left + mask_br = mask[1:, 1:] # bottom-right + + # A patch is valid if any of its 4 vertices is in the mask + valid_patches = mask_tl | mask_tr | mask_bl | mask_br # [H-1, W-1] + + # Get indices of valid patches + valid_h, valid_w = torch.where(valid_patches) + + # For each valid patch, create two triangles + # Triangle 1: (u,v), (u,v+1), (u+1,v) + # Triangle 2: (u,v+1), (u+1,v+1), (u+1,v) + n_valid = len(valid_h) + + if n_valid == 0: + # No valid patches, return empty mesh + return torch.empty((0, 3), device=points.device), torch.empty((0, 3), dtype=torch.long, device=points.device) + + # Vectorized triangle creation + idx_tl = vertex_indices[valid_h, valid_w] # top-left + idx_tr = vertex_indices[valid_h, valid_w + 1] # top-right + idx_bl = vertex_indices[valid_h + 1, valid_w] # bottom-left + idx_br = vertex_indices[valid_h + 1, valid_w + 1] # bottom-right + + # Create faces (2 triangles per patch) + faces1 = torch.stack([idx_tl, idx_tr, idx_bl], dim=1) # [n_valid, 3] + faces2 = torch.stack([idx_tr, idx_br, idx_bl], dim=1) # [n_valid, 3] + faces = torch.cat([faces1, faces2], dim=0) # [2*n_valid, 3] + + # Flatten points to get vertices + vertices = points.reshape(-1, 3) # [H*W, 3] + + # Optional: Remove unused vertices and remap faces + # First, find which vertices are actually used + used_vertices = torch.unique(faces.flatten()) + + # Create a mapping from old indices to new indices + new_idx_map = torch.full((H * W,), -1, dtype=torch.long, device=points.device) + new_idx_map[used_vertices] = torch.arange(len(used_vertices), device=points.device) + + # Extract only used vertices + vertices = vertices[used_vertices] + + # Remap face indices + faces = new_idx_map[faces.flatten()].reshape(-1, 3) + + return vertices, faces + +def get_max_exponent_for_dtype(dtype): + # Set the maximum exponent based on dtype + if dtype == torch.bfloat16: + return 80.0 # Safe maximum exponent for bfloat16 + elif dtype == torch.float16: + return 10.0 # Safe maximum exponent for float16 + elif dtype == torch.float32: + return 80.0 # Safe maximum exponent for float32 + elif dtype == torch.float64: + return 700.0 # Safe maximum exponent for float64 + else: + return 80.0 # Default safe value + +def inverse_with_conversion(mtx): + return 
torch.linalg.inv(mtx.to(torch.float32)).to(mtx.dtype) + + +def get_camera_rays(h, w, intrinsic: np.ndarray) -> np.ndarray: + """Backproject 2D pixels into 3D rays.""" + device = intrinsic.device + x1d = torch.arange(0, w, device=device, dtype=intrinsic.dtype)[None] + y1d = torch.arange(0, h, device=device, dtype=intrinsic.dtype)[:, None] + x2d = x1d.repeat([h, 1]) # .to(intrinsic) # (h, w) + y2d = y1d.repeat([1, w]) # .to(intrinsic) # (h, w) + ones_2d = torch.ones(size=(h, w), device=device, dtype=intrinsic.dtype) # .to(intrinsic) # (h, w) + pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None] # (1, h, w, 3, 1) + + intrinsic1_inv = inverse_with_conversion(intrinsic) # (b, 3, 3) + intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3) + # Normalize the rays + unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo).squeeze(-1) + # Normalize the rays + norm = torch.norm(unnormalized_pos, dim=-1, keepdim=True) + norm[norm == 0] = 1 + return unnormalized_pos / norm + + +def forward_warp( + frame1: torch.Tensor, + mask1: Optional[torch.Tensor], + depth1: Optional[torch.Tensor], + transformation1: Optional[torch.Tensor], + transformation2: torch.Tensor, + intrinsic1: Optional[torch.Tensor], + intrinsic2: Optional[torch.Tensor], + is_image=True, + conditioned_normal1=None, + cameraray_filtering=False, + is_depth=True, + render_depth=False, + world_points1=None, + foreground_masking=False, + boundary_mask=None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given a frame1 and global transformations transformation1 and transformation2, warps frame1 to next view using + bilinear splatting. + All arrays should be torch tensors with batch dimension and channel first + :param frame1: (b, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling + bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting() + method accordingly. + :param mask1: (b, 1, h, w) - 1 for known, 0 for unknown. Optional + :param depth1: (b, 1, h, w) + :param transformation1: (b, 4, 4) extrinsic transformation matrix (camera-to-world pose) of first view. Required if depth1 is not None, or if cleaning is enabled. + :param transformation2: (b, 4, 4) extrinsic transformation matrix (camera-to-world pose) of second view. + :param intrinsic1: (b, 3, 3) camera intrinsic matrix. Required if depth1 is not None. + :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional (defaults to intrinsic1 if provided). + :param is_image: bool, whether frame1 represents image data (affects clipping and fill value). + :param conditioned_normal1: Optional (b, 3, h, w) normals for filtering. + :param cameraray_filtering: bool, use camera rays for filtering instead of normals. + :param is_depth: bool, whether depth1 represents depth along Z or distance to camera center. Used only if depth1 is not None. + :param render_depth: bool, whether to also render and return the warped depth map. + :param world_points1: Optional (b, h, w, 3) world points. Required if depth1 is None. + :param foreground_masking: bool, enable foreground occlusion masking using mesh rendering. + :param boundary_mask: Optional (b, h, w) mask for mesh generation, required if foreground_masking is True. 
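+    :return: tuple (warped_frame2, mask2, warped_depth2, flow12):
+             warped_frame2 (b, c, h, w) is the frame splatted into the second view,
+             mask2 (b, 1, h, w) marks pixels that received any contribution,
+             warped_depth2 (b, h, w) is the splatted depth (None unless
+             render_depth or foreground_masking is set), and flow12 (b, 2, h, w)
+             is the induced flow from view 1 to view 2.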
+ """ + device = frame1.device + b, c, h, w = frame1.shape + dtype = frame1.dtype + if mask1 is None: + mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=frame1.dtype) + if intrinsic2 is None: + assert intrinsic1 is not None, "intrinsic2 cannot be derived if intrinsic1 is None and intrinsic2 is None" + intrinsic2 = intrinsic1.clone() + + if depth1 is None: + assert world_points1.shape == (b, h, w, 3) + if foreground_masking: + trans_points1, cam_points_target = project_points(world_points1, transformation2, intrinsic2, return_cam_points=True) + else: + trans_points1 = project_points(world_points1, transformation2, intrinsic2) + else: + # assert frame1.shape == (b, 3, h, w) + assert mask1.shape == (b, 1, h, w) + assert depth1.shape == (b, 1, h, w) + assert transformation1.shape == (b, 4, 4) + assert transformation2.shape == (b, 4, 4) + assert intrinsic1.shape == (b, 3, 3) + assert intrinsic2.shape == (b, 3, 3) + + depth1 = torch.nan_to_num(depth1, nan=1e4) + depth1 = torch.clamp(depth1, min=0, max=1e4) + if foreground_masking: + trans_points1, cam_points_target = compute_transformed_points( + depth1, transformation1, transformation2, intrinsic1, is_depth, intrinsic2, return_cam_points=True + ) + else: + trans_points1 = compute_transformed_points( + depth1, transformation1, transformation2, intrinsic1, is_depth, intrinsic2 + ) + mask1 = mask1 * (trans_points1[:, :, :, 2, 0].unsqueeze(1) > 0) + trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7) + trans_coordinates = trans_coordinates.permute(0, 3, 1, 2) # b, 2, h, w + trans_depth1 = trans_points1[:, :, :, 2, 0].unsqueeze(1) + + grid = create_grid(b, h, w, device=device, dtype=dtype) # .to(trans_coordinates) + flow12 = trans_coordinates - grid + if conditioned_normal1 is not None or cameraray_filtering: + camera_rays = get_camera_rays(h, w, intrinsic1) # b, h, w, 3 + transformation = torch.bmm(transformation2, inverse_with_conversion(transformation1)) + transformation[:, :3, 3] = 0 + trans_4d = transformation[:, None, None] + if cameraray_filtering: # use normal for filtering + conditioned_normal1 = camera_rays + inversion_vector = torch.tensor([-1, -1, -1], dtype=camera_rays.dtype, device=camera_rays.device).view( + 1, 1, 1, 3, 1 + ) + else: # use normal for filtering + assert conditioned_normal1.shape == (b, 3, h, w) + inversion_vector = torch.tensor([-1, 1, 1], dtype=camera_rays.dtype, device=camera_rays.device).view( + 1, 1, 1, 3, 1 + ) + conditioned_normal1 = conditioned_normal1.permute(0, 2, 3, 1) + # rotate normal into target camera spaces + normal_4d = conditioned_normal1.unsqueeze(-1) + b, _, h, w = depth1.shape + ones_2d = torch.ones(size=(h, w), device=device, dtype=dtype) # .to(depth1) # (h, w) + ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) + normal_4d_homo = torch.cat([normal_4d * inversion_vector, ones_4d], dim=3) + + trans_normal = torch.matmul(trans_4d, normal_4d_homo).squeeze(-1)[..., :3] # b, h, w, 3 + dot_product = torch.sum(trans_normal * camera_rays, dim=-1) + + # Create binary mask for angles < 90 degrees + binary_mask = dot_product > 0 + # import ipdb;ipdb.set_trace() + mask1 *= binary_mask.unsqueeze(1) + warped_frame2, mask2 = bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image) + warped_depth2 = None + if render_depth or foreground_masking: + warped_depth2 = bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False)[0][:, 0] + if foreground_masking: + for batch_idx in range(b): + assert 
boundary_mask is not None + mesh_mask = boundary_mask[batch_idx] + + mesh_downsample_factor = 4 + vertices_masked, faces_masked = points_to_mesh( + cam_points_target[batch_idx], + mesh_mask, + resolution=(h // mesh_downsample_factor, w // mesh_downsample_factor) + ) + + if vertices_masked.shape[0] == 0 or faces_masked.shape[0] == 0: + continue + + ray_scale_factor = 1 + ray_downsampled_h = h // ray_scale_factor + ray_downsampled_w = w // ray_scale_factor + current_intrinsic_batch = intrinsic2[batch_idx:batch_idx+1] + scaled_intrinsic = current_intrinsic_batch.clone() + + scaled_intrinsic[0, 0, 0] /= ray_scale_factor # fx + scaled_intrinsic[0, 1, 1] /= ray_scale_factor # fy + scaled_intrinsic[0, 0, 2] /= ray_scale_factor # cx + scaled_intrinsic[0, 1, 2] /= ray_scale_factor # cy + + camera_rays = get_camera_rays(ray_downsampled_h, ray_downsampled_w, scaled_intrinsic) # (1, h_ds, w_ds, 3) + camera_rays = camera_rays[0] # (h_ds, w_ds, 3) + + ray_origins = torch.zeros((ray_downsampled_h, ray_downsampled_w, 3), device=device, dtype=dtype) + + mesh_depth = ray_triangle_intersection( + ray_origins, + camera_rays, + vertices_masked, + faces_masked, + device + ) + ray_z = camera_rays[:, :, 2] # (h, w) + mesh_z_depth = mesh_depth * ray_z # Convert to z-depth + mesh_z_depth = F.interpolate(mesh_z_depth.unsqueeze(0).unsqueeze(0), size=(h, w), mode='bilinear').squeeze(0).squeeze(0) + + warped_depth_batch = warped_depth2[batch_idx] # (h, w) + + + mesh_valid = mesh_z_depth > 0 + mesh_closer = ((mesh_z_depth + 0.02) < warped_depth_batch) & mesh_valid + + mask2[batch_idx, 0] = mask2[batch_idx, 0] * (~mesh_closer).float() + warped_frame2[batch_idx] = (warped_frame2[batch_idx] + 1) * (~mesh_closer.unsqueeze(0)).float() - 1 + warped_depth2[batch_idx] = warped_depth2[batch_idx] * (~mesh_closer.unsqueeze(0)).float() + return warped_frame2, mask2, warped_depth2, flow12 + +def reliable_depth_mask_range_batch(depth, window_size=5, ratio_thresh=0.05, eps=1e-6): + assert window_size % 2 == 1, "Window size must be odd." + if depth.dim() == 3: # Input shape: (B, H, W) + depth_unsq = depth.unsqueeze(1) + elif depth.dim() == 4: # Already has shape (B, 1, H, W) + depth_unsq = depth + else: + raise ValueError("depth tensor must be of shape (B, H, W) or (B, 1, H, W)") + + local_max = torch.nn.functional.max_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) + local_min = -torch.nn.functional.max_pool2d(-depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) + local_mean = torch.nn.functional.avg_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) + ratio = (local_max - local_min) / (local_mean + eps) + reliable_mask = (ratio < ratio_thresh) & (depth_unsq > 0) + + return reliable_mask + +def double_forward_warp( + frame1: torch.Tensor, + mask1: torch.Tensor, + depth1: torch.Tensor, + intrinsic1: torch.Tensor, + double_proj_w2cs: torch.Tensor, +): + """ + Double projection using forward warping with your APIs. + + 1. Warps frame1 from the original view (identity transformation) + to the target view defined by double_proj_w2cs. + 2. Computes a warped flow field and then warps the intermediate result + back to the original view using the original depth. + + :param frame1: (b, 3, h, w) original image. + :param mask1: (b, 1, h, w) valid mask. + :param depth1: (b, 1, h, w) depth map. + :param intrinsic1: (b, 3, 3) intrinsic matrix. + :param double_proj_w2cs: (b, 4, 4) target view transformation. 
+ :return: twice_warped_frame1, warped_frame2, None, None + """ + b, c, h, w = frame1.shape + device, dtype = frame1.device, frame1.dtype + + if mask1 is None: + mask1 = torch.ones((b, 1, h, w), device=device, dtype=dtype) + + # Use identity transformation for the original view. + identity = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(b, 1, 1) + + trans_points = compute_transformed_points( + depth1, identity, double_proj_w2cs, intrinsic1, is_depth=True, intrinsic2=intrinsic1 + ) + trans_coordinates = trans_points[:, :, :, :2, 0] / (trans_points[:, :, :, 2:3, 0] + 1e-7) + trans_depth = trans_points[:, :, :, 2, 0] + + grid = create_grid(b, h, w, device=device, dtype=dtype) + flow12 = trans_coordinates.permute(0, 3, 1, 2) - grid + + warped_frame2, mask2 = bilinear_splatting( + frame1, mask1, trans_depth.unsqueeze(1), flow12, None, is_image=True, n_views=1, depth_weight_scale=50 + ) + + warped_flow, _ = bilinear_splatting( + flow12, mask1, trans_depth.unsqueeze(1), flow12, None, is_image=False, n_views=1, depth_weight_scale=50 + ) + + twice_warped_frame1, twice_warped_mask1 = bilinear_splatting( + warped_frame2, mask2, depth1, -warped_flow, None, is_image=True, n_views=1, depth_weight_scale=50 + ) + + return twice_warped_frame1, twice_warped_mask1, warped_frame2, mask2 + + +def unproject_points(depth: torch.Tensor, + w2c: torch.Tensor, + intrinsic: torch.Tensor, + is_depth: bool = True, + mask: Optional[torch.Tensor] = None): + + b, _, h, w = depth.shape + device = depth.device + dtype = depth.dtype + if mask is None: + mask = depth > 0 + if mask.dim() == depth.dim() and mask.shape[1] == 1: + mask = mask[:, 0] + + idx = torch.nonzero(mask) + if idx.numel() == 0: + return torch.zeros((b, h, w, 3), device=device, dtype=dtype) + + b_idx, y_idx, x_idx = idx[:, 0], idx[:, 1], idx[:, 2] + + + intrinsic_inv = inverse_with_conversion(intrinsic) # (b, 3, 3) + + x_valid = x_idx.to(dtype) + y_valid = y_idx.to(dtype) + ones = torch.ones_like(x_valid) + pos = torch.stack([x_valid, y_valid, ones], dim=1).unsqueeze(-1) # (N, 3, 1) + + intrinsic_inv_valid = intrinsic_inv[b_idx] # (N, 3, 3) + unnormalized_pos = torch.matmul(intrinsic_inv_valid, pos) # (N, 3, 1) + + depth_valid = depth[b_idx, 0, y_idx, x_idx].view(-1, 1, 1) + if is_depth: + world_points_cam = depth_valid * unnormalized_pos + else: + norm_val = torch.norm(unnormalized_pos, dim=1, keepdim=True) + direction = unnormalized_pos / (norm_val + 1e-8) + world_points_cam = depth_valid * direction + + ones_h = torch.ones((world_points_cam.shape[0], 1, 1), + device=device, dtype=dtype) + world_points_homo = torch.cat([world_points_cam, ones_h], dim=1) # (N, 4, 1) + + trans = inverse_with_conversion(w2c) # (b, 4, 4) + trans_valid = trans[b_idx] # (N, 4, 4) + world_points_transformed = torch.matmul(trans_valid, world_points_homo) # (N, 4, 1) + sparse_points = world_points_transformed[:, :3, 0] # (N, 3) + + out_points = torch.zeros((b, h, w, 3), device=device, dtype=dtype) + out_points[b_idx, y_idx, x_idx, :] = sparse_points + return out_points + +def project_points(world_points: torch.Tensor, w2c: torch.Tensor, intrinsic: torch.Tensor, return_cam_points: bool = False): + """ + Projects 3D world points back into 2D pixel space. 
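+
+    Args:
+        world_points: (b, h, w, 3) points in world coordinates.
+        w2c: (b, 4, 4) world-to-camera transformation.
+        intrinsic: (b, 3, 3) camera intrinsics.
+        return_cam_points: if True, also return the camera-space points.
+
+    Returns:
+        projected_points: (b, h, w, 3, 1) homogeneous pixel coordinates (not yet
+        divided by the z component); if return_cam_points is True, additionally
+        cam_points_3d: (b, h, w, 3) points in the target camera frame.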
+ """ + world_points = world_points.unsqueeze(-1) # (b, h, w, 3) -> # (b, h, w, 3, 1) + b, h, w, _, _ = world_points.shape + + ones_4d = torch.ones((b, h, w, 1, 1), device=world_points.device, dtype=world_points.dtype) # (b, h, w, 1, 1) + world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1) + + # Apply transformation2 to convert world points to camera space + trans_4d = w2c[:, None, None] # (b, 1, 1, 4, 4) + camera_points_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1) + + # Remove homogeneous coordinate and project to image plane + camera_points = camera_points_homo[:, :, :, :3] # (b, h, w, 3, 1) + intrinsic_4d = intrinsic[:, None, None] # (b, 1, 1, 3, 3) + projected_points = torch.matmul(intrinsic_4d, camera_points) # (b, h, w, 3, 1) + + if return_cam_points: + # Return both projected points and camera space points + cam_points_3d = camera_points.squeeze(-1) # (b, h, w, 3) + return projected_points, cam_points_3d + else: + return projected_points + + +def unproject_depth_torch( + depth1: torch.Tensor, + transformation1: torch.Tensor, + intrinsic1: torch.Tensor, +) -> torch.Tensor: + b, c, h, w = depth1.shape + assert depth1.shape == (b, 1, h, w) + assert transformation1.shape == (b, 4, 4) + assert intrinsic1.shape == (b, 3, 3) + device = depth1.device + x1d = torch.arange(0, w, device=device)[None] + y1d = torch.arange(0, h, device=device)[:, None] + x2d = x1d.repeat([h, 1]) # .to(depth1) # (h, w) + y2d = y1d.repeat([1, w]) # .to(depth1) # (h, w) + ones_2d = torch.ones(size=(h, w), device=device) # .to(depth1) # (h, w) + ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) # (b, h, w, 1, 1) + pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None] # (1, h, w, 3, 1) + + intrinsic1_inv = inverse_with_conversion(intrinsic1) # (b, 3, 3) + intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3) + + depth_4d = depth1[:, 0][:, :, :, None, None] # (b, h, w, 1, 1) + + unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo) # (b, h, w, 3, 1) + world_points = depth_4d * unnormalized_pos # (b, h, w, 3, 1) + + world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1) + trans_4d = transformation1[:, None, None] # (b, 1, 1, 4, 4) + trans_world_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1) + trans_world = trans_world_homo[:, :, :, :3] # (b, h, w, 3, 1) + trans_world = trans_world.squeeze(dim=-1) + return trans_world + + +def compute_transformed_points( + depth1: torch.Tensor, + transformation1: torch.Tensor, + transformation2: torch.Tensor, + intrinsic1: torch.Tensor, + is_depth: bool = True, + intrinsic2: Optional[torch.Tensor] = None, + return_cam_points: bool = False, +): + """ + Computes transformed position for each pixel location + """ + b, _, h, w = depth1.shape + if intrinsic2 is None: + intrinsic2 = intrinsic1.clone() + transformation = torch.bmm( + transformation2, inverse_with_conversion(transformation1) + ) # (b, 4, 4) transformation is w2c + device = depth1.device + x1d = torch.arange(0, w, device=device, dtype=depth1.dtype)[None] + y1d = torch.arange(0, h, device=device, dtype=depth1.dtype)[:, None] + x2d = x1d.repeat([h, 1]) # .to(depth1) # (h, w) + y2d = y1d.repeat([1, w]) # .to(depth1) # (h, w) + ones_2d = torch.ones(size=(h, w), device=device, dtype=depth1.dtype) # .to(depth1) # (h, w) + ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) # (b, h, w, 1, 1) + pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, 
:, :, :, None] # (1, h, w, 3, 1) + + intrinsic1_inv = inverse_with_conversion(intrinsic1) # (b, 3, 3) + intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3) + intrinsic2_4d = intrinsic2[:, None, None] # (b, 1, 1, 3, 3) + depth_4d = depth1[:, 0][:, :, :, None, None] # (b, h, w, 1, 1) + trans_4d = transformation[:, None, None] # (b, 1, 1, 4, 4) + + unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo) # (b, h, w, 3, 1) + if is_depth: + world_points = depth_4d * unnormalized_pos # (b, h, w, 3, 1) + else: # if 'depth' is defined as distance to camera center + direction_vectors = unnormalized_pos / torch.norm(unnormalized_pos, dim=-2, keepdim=True) # (b, h, w, 3, 1) + world_points = depth_4d * direction_vectors # (b, h, w, 3, 1) + + world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1) + trans_world_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1) + trans_world = trans_world_homo[:, :, :, :3] # (b, h, w, 3, 1) + trans_norm_points = torch.matmul(intrinsic2_4d, trans_world) # (b, h, w, 3, 1) + + if return_cam_points: + # Return both projected points and camera space points + cam_points = trans_world.squeeze(-1) # (b, h, w, 3) + return trans_norm_points, cam_points + else: + return trans_norm_points + + +def bilinear_splatting( + frame1: torch.Tensor, + mask1: Optional[torch.Tensor], + depth1: torch.Tensor, + flow12: torch.Tensor, + flow12_mask: Optional[torch.Tensor], + is_image: bool = False, + n_views=1, + depth_weight_scale=50, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Bilinear splatting + :param frame1: (b,c,h,w) + :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional + :param depth1: (b,1,h,w) + :param flow12: (b,2,h,w) + :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. 
Optional + :param is_image: if true, output will be clipped to (-1,1) range + :return: warped_frame2: (b,c,h,w) + mask2: (b,1,h,w): 1 for known and 0 for unknown + """ + b, c, h, w = frame1.shape + device = frame1.device + dtype = frame1.dtype + if mask1 is None: + mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype) # .to(frame1) + if flow12_mask is None: + flow12_mask = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype) # .to(flow12) + grid = create_grid(b, h, w, device=device, dtype=dtype).to(dtype) # .to(frame1) + trans_pos = flow12 + grid + + trans_pos_offset = trans_pos + 1 + trans_pos_floor = torch.floor(trans_pos_offset).long() + trans_pos_ceil = torch.ceil(trans_pos_offset).long() + trans_pos_offset = torch.stack( + [torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], + dim=1, + ) + trans_pos_floor = torch.stack( + [torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], + dim=1, + ) + trans_pos_ceil = torch.stack( + [torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], + dim=1, + ) + + prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * ( + 1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]) + ) + prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * ( + 1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]) + ) + prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * ( + 1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]) + ) + prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * ( + 1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]) + ) + + # Calculate depth weights, preventing overflow and removing saturation + # Clamp depth to be non-negative before log1p + clamped_depth1 = torch.clamp(depth1, min=0) + log_depth1 = torch.log1p(clamped_depth1) # Use log1p for better precision near 0 + # Normalize and scale log depth + exponent = log_depth1 / (log_depth1.max() + 1e-7) * depth_weight_scale + # Clamp exponent before exp to prevent overflow + max_exponent = get_max_exponent_for_dtype(depth1.dtype) + clamped_exponent = torch.clamp(exponent, max=max_exponent) + # Compute depth weights with added epsilon for stability when dividing later + depth_weights = torch.exp(clamped_exponent) + 1e-7 + + + weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) + + warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=dtype, device=device) # .to(frame1) + warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=dtype, device=device) # .to(frame1) + + frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2]) + batch_indices = torch.arange(b, device=device, dtype=torch.long)[:, None, None] # .to(frame1.device) + warped_frame.index_put_( + (batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), frame1_cl * weight_nw, accumulate=True + ) + warped_frame.index_put_( + (batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), frame1_cl * weight_sw, accumulate=True + ) + warped_frame.index_put_( + 
(batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), frame1_cl * weight_ne, accumulate=True + ) + warped_frame.index_put_( + (batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), frame1_cl * weight_se, accumulate=True + ) + + warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), weight_nw, accumulate=True) + warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), weight_sw, accumulate=True) + warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), weight_ne, accumulate=True) + warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), weight_se, accumulate=True) + if n_views > 1: + warped_frame = warped_frame.reshape(b // n_views, n_views, h + 2, w + 2, c).sum(1) + warped_weights = warped_weights.reshape(b // n_views, n_views, h + 2, w + 2, 1).sum(1) + + warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1]) + warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1]) + cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1] + cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1] + cropped_weights = torch.nan_to_num(cropped_weights, nan=1000.0) + + mask = cropped_weights > 0 + zero_value = -1 if is_image else 0 + zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device) + warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor) + mask2 = mask.to(frame1) + if is_image: + # assert warped_frame2.min() >= -1.1 # Allow for rounding errors + # assert warped_frame2.max() <= 1.1 + warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1) + return warped_frame2, mask2 + +def create_grid(b: int, h: int, w: int, device="cpu", dtype=torch.float) -> torch.Tensor: + """ + Create a dense grid of (x,y) coordinates of shape (b, 2, h, w). + """ + x = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w) + y = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w) + return torch.cat([x, y], dim=1) + +def ray_triangle_intersection( + ray_origins: torch.Tensor, # (H, W, 3) + ray_directions: torch.Tensor, # (H, W, 3) + vertices: torch.Tensor, # (N, 3) + faces: torch.Tensor, # (M, 3) + device: torch.device +) -> torch.Tensor: + """ + Compute ray-triangle intersections for all rays and triangles. + Returns depth map of shape (H, W) with intersection distances. + + Uses NVIDIA Warp acceleration for fast performance. + """ + _init_warp() + return _ray_triangle_intersection_func( + ray_origins, ray_directions, vertices, faces, device + ) \ No newline at end of file diff --git a/cosmos_predict1/diffusion/inference/gen3c_dynamic.py b/cosmos_predict1/diffusion/inference/gen3c_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fb6367de4f136b081a5222ab513ef5489bc3b0 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_dynamic.py @@ -0,0 +1,363 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import torch +import numpy as np +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, +) +from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video +from cosmos_predict1.diffusion.inference.cache_3d import Cache4D +from cosmos_predict1.diffusion.inference.camera_utils import generate_camera_trajectory +from cosmos_predict1.diffusion.inference.data_loader_utils import load_data_auto_detect +import torch.nn.functional as F +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) # TODO: do we need this? + parser.add_argument( + "--input_image_path", + type=str, + help="Input image path for generating a single video", + ) + parser.add_argument( + "--trajectory", + type=str, + choices=[ + "left", + "right", + "up", + "down", + "zoom_in", + "zoom_out", + "clockwise", + "counterclockwise", + ], + default="left", + help="Select a trajectory type from the available options (default: original)", + ) + parser.add_argument( + "--camera_rotation", + type=str, + choices=["center_facing", "no_rotation", "trajectory_aligned"], + default="center_facing", + help="Controls camera rotation during movement: center_facing (rotate to look at center), no_rotation (keep orientation), or trajectory_aligned (rotate in the direction of movement)", + ) + parser.add_argument( + "--movement_distance", + type=float, + default=0.3, + help="Distance of the camera from the center of the scene", + ) + parser.add_argument( + "--save_buffer", + action="store_true", + help="If set, save the warped images (buffer) side by side with the output video.", + ) + parser.add_argument( + "--filter_points_threshold", + type=float, + default=0.05, + help="If set, filter the points continuity of the warped images.", + ) + parser.add_argument( + "--foreground_masking", + action="store_true", + help="If set, use foreground masking for the warped images.", + ) + return parser.parse_args() + +def validate_args(args): + assert args.num_video_frames is not None, "num_video_frames must be provided" + assert (args.num_video_frames - 1) % 120 == 0, "num_video_frames must be 121, 241, 361, ... (N*120+1)" + + +def demo(args): + """Run video-to-world generation demo. 
+ + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = Gen3cPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Gen3C-Cosmos-7B", + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=121, + seed=args.seed, + ) + + sample_n_frames = pipeline.model.chunk_size + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_path}] + + os.makedirs(os.path.dirname(args.video_save_folder), exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_video_path = input_dict.get("visual_input", None) + if current_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Load data using the new auto-detect loader (supports both old pt and new format) + try: + ( + image_bchw_float, + depth_b1hw, + mask_b1hw, + initial_w2c_b44, + intrinsics_b33, + ) = load_data_auto_detect(current_video_path) + except Exception as e: + log.critical(f"Failed to load visual input from {current_video_path}: {e}") + continue + + image_bchw_float = image_bchw_float.to(device) + depth_b1hw = depth_b1hw.to(device) + mask_b1hw = 
mask_b1hw.to(device) + initial_w2c_b44 = initial_w2c_b44.to(device) + intrinsics_b33 = intrinsics_b33.to(device) + + cache = Cache4D( + input_image=image_bchw_float.clone(), # [B, C, H, W] + input_depth=depth_b1hw, # [B, 1, H, W] + input_mask=mask_b1hw, # [B, 1, H, W] + input_w2c=initial_w2c_b44, # [B, 4, 4] + input_intrinsics=intrinsics_b33,# [B, 3, 3] + filter_points_threshold=args.filter_points_threshold, + input_format=["F", "C", "H", "W"], + foreground_masking=args.foreground_masking, + ) + + initial_cam_w2c_for_traj = initial_w2c_b44 + initial_cam_intrinsics_for_traj = intrinsics_b33 + + # Generate camera trajectory using the new utility function + try: + generated_w2cs, generated_intrinsics = generate_camera_trajectory( + trajectory_type=args.trajectory, + initial_w2c=initial_cam_w2c_for_traj, + initial_intrinsics=initial_cam_intrinsics_for_traj, + num_frames=args.num_video_frames, + movement_distance=args.movement_distance, + camera_rotation=args.camera_rotation, + center_depth=1.0, + device=device.type, + ) + except (ValueError, NotImplementedError) as e: + log.critical(f"Failed to generate trajectory: {e}") + continue + + log.info(f"Generating 0 - {sample_n_frames} frames") + + rendered_warp_images, rendered_warp_masks = cache.render_cache( + generated_w2cs[:, 0:sample_n_frames], + generated_intrinsics[:, 0:sample_n_frames], + start_frame_idx=0, + ) + + all_rendered_warps = [] + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images.clone().cpu()) + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=image_bchw_float[0].unsqueeze(0).unsqueeze(2), + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + num_ar_iterations = (generated_w2cs.shape[1] - 1) // (sample_n_frames - 1) + for num_iter in range(1, num_ar_iterations): + start_frame_idx = num_iter * (sample_n_frames - 1) # Overlap by 1 frame + end_frame_idx = start_frame_idx + sample_n_frames + + log.info(f"Generating {start_frame_idx} - {end_frame_idx} frames") + + last_frame_hwc_0_255 = torch.tensor(video[-1], device=device) + pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + current_segment_w2cs = generated_w2cs[:, start_frame_idx:end_frame_idx] + current_segment_intrinsics = generated_intrinsics[:, start_frame_idx:end_frame_idx] + rendered_warp_images, rendered_warp_masks = cache.render_cache( + current_segment_w2cs, + current_segment_intrinsics, + start_frame_idx=start_frame_idx, + ) + + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images[:, 1:].clone().cpu()) + + + pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=pred_image_for_depth_bcthw_minus1_1, + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + video_new, prompt = generated_output + video = np.concatenate([video, video_new[1:]], axis=0) + + # Final video processing + final_video_to_save = video + final_width = args.width + + if args.save_buffer and all_rendered_warps: + squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is (T_chunk, n_i, C, H, W) + + if squeezed_warps: + 
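+ # Each chunk's warp buffer has shape (T_chunk, n_i, C, H, W), where n_i is the number of
+ # warped views rendered from the cache for that chunk. Chunks are padded along n_i to a
+ # common n_max with value -1 (which maps to black after the [-1, 1] -> [0, 255]
+ # conversion), concatenated over time, and the n_max views are tiled side by side along
+ # the width axis next to the generated video.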
n_max = max(t.shape[1] for t in squeezed_warps) + + padded_t_list = [] + for sq_t in squeezed_warps: + # sq_t shape: (T_chunk, n_i, C, H, W) + current_n_i = sq_t.shape[1] + padding_needed_dim1 = n_max - current_n_i + + pad_spec = (0,0, # W + 0,0, # H + 0,0, # C + 0,padding_needed_dim1, # n_i + 0,0) # T_chunk + padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0) + padded_t_list.append(padded_t) + + full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0) + + T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape + buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4) + buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim) + buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0 + buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8) + buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1)) + + final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2) + final_width = args.width * (1 + n_max) + log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}") + else: + log.info("No warp buffers to save.") + + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + + os.makedirs(os.path.dirname(video_save_path), exist_ok=True) + + # Save video + save_video( + video=final_video_to_save, + fps=args.fps, + H=args.height, + W=final_width, + video_save_quality=5, + video_save_path=video_save_path, + ) + log.info(f"Saved video to {video_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + if args.prompt is None: + args.prompt = "" + args.disable_guardrail = True + args.disable_prompt_upsampler = True + demo(args) diff --git a/cosmos_predict1/diffusion/inference/gen3c_persistent.py b/cosmos_predict1/diffusion/inference/gen3c_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..8113d53f306781282fad2dafc700a79ae60cf266 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_persistent.py @@ -0,0 +1,569 @@ +import argparse +import os +import time + +from moge.model.v1 import MoGeModel +import torch +import numpy as np +from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline +from cosmos_predict1.diffusion.inference.gen3c_single_image import ( + create_parser as create_parser_base, + validate_args as validate_args_base, + _predict_moge_depth, + _predict_moge_depth_from_tensor +) +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.distributed import device_with_rank, is_rank0, get_rank +from cosmos_predict1.utils.io import save_video +from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer, Cache4D +import torch.nn.functional as F + + +def create_parser(): + return create_parser_base() + + +def validate_args(args: argparse.Namespace): + validate_args_base(args) + assert args.batch_input_path is None, "Unsupported in persistent mode" + assert args.prompt is not None, "Prompt is required in persistent mode (but it can be the empty string)" + assert args.input_image_path is None, "Image should be provided directly by value in persistent mode" + assert args.trajectory in (None, 'none'), 
"Trajectory should be provided directly by value in persistent mode, set --trajectory=none" + assert not args.video_save_name, f"Video saving name will be set automatically for each inference request. Found string: \"{args.video_save_name}\"" + + +def resize_intrinsics(intrinsics: np.ndarray | torch.Tensor, + old_size: tuple[int, int], new_size: tuple[int, int], + crop_size: tuple[int, int] | None = None) -> np.ndarray | torch.Tensor: + # intrinsics: (3, 3) + # old_size: (h1, w1) + # new_size: (h2, w2) + if isinstance(intrinsics, np.ndarray): + intrinsics_copy = np.copy(intrinsics) + elif isinstance(intrinsics, torch.Tensor): + intrinsics_copy = intrinsics.clone() + else: + raise ValueError(f"Invalid intrinsics type: {type(intrinsics)}") + intrinsics_copy[:, 0, :] *= new_size[1] / old_size[1] + intrinsics_copy[:, 1, :] *= new_size[0] / old_size[0] + if crop_size is not None: + intrinsics_copy[:, 0, -1] = intrinsics_copy[:, 0, -1] - (new_size[1] - crop_size[1]) / 2 + intrinsics_copy[:, 1, -1] = intrinsics_copy[:, 1, -1] - (new_size[0] - crop_size[0]) / 2 + return intrinsics_copy + + +class Gen3cPersistentModel(): + """Helper class to run Gen3C image-to-video or video-to-video inference. + + This class loads the models only once and can be reused for multiple inputs. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + """ + + @torch.no_grad() + def __init__(self, args: argparse.Namespace): + misc.set_random_seed(args.seed) + validate_args(args) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + self.frames_per_batch = 121 + self.inference_overlap_frames = 1 + + # Initialize video2world generation model pipeline + pipeline = Gen3cPipeline( + inference_type="video2world", + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Gen3C-Cosmos-7B", + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=self.frames_per_batch, + seed=args.seed, + ) + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + self.args = args + 
self.frame_buffer_max = pipeline.model.frame_buffer_max + self.generator = torch.Generator(device=device).manual_seed(args.seed) + self.sample_n_frames = pipeline.model.chunk_size + self.moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) + self.pipeline = pipeline + self.device = device + self.device_with_rank = device_with_rank(self.device) + + self.cache: Cache3D_Buffer | Cache4D | None = None + self.model_was_seeded = False + # User-provided seeding image, after pre-processing. + # Shape [B, C, T, H, W], type float, range [-1, 1]. + self.seeding_image: torch.Tensor | None = None + + + @torch.no_grad() + def seed_model_from_values(self, + images_np: np.ndarray, + depths_np: np.ndarray | None, + world_to_cameras_np: np.ndarray, + focal_lengths_np: np.ndarray, + principal_point_rel_np: np.ndarray, + resolutions: np.ndarray, + masks_np: np.ndarray | None = None): + import torchvision.transforms.functional as transforms_F + + # Check inputs + n = images_np.shape[0] + assert images_np.shape[-1] == 3 + assert world_to_cameras_np.shape == (n, 4, 4) + assert focal_lengths_np.shape == (n, 2) + assert principal_point_rel_np.shape == (n, 2) + assert resolutions.shape == (n, 2) + assert (depths_np is None) or (depths_np.shape == images_np.shape[:-1]) + assert (masks_np is None) or (masks_np.shape == images_np.shape[:-1]) + + + if n == 1: + # TODO: allow user to provide depths, extrinsics and intrinsics + assert depths_np is None, "Not supported yet: directly providing pre-estimated depth values along with a single image." + + # Note: image is received as 0..1 float, but MoGE expects 0..255 uint8. + input_image_np = images_np[0, ...] * 255.0 + del images_np + + # Predict depth and initialize 3D cache. + # Note: even though internally MoGE may use a different resolution, all of the outputs + # are properly resized & adapted to our desired (self.args.height, self.args.width) resolution, + # including the intrinsics. + ( + moge_image_b1chw_float, + moge_depth_b11hw, + moge_mask_b11hw, + moge_initial_w2c_b144, + moge_intrinsics_b133, + ) = _predict_moge_depth( + input_image_np, self.args.height, self.args.width, self.device_with_rank, self.moge_model + ) + + # TODO: MoGE provides camera params, is it okay to just ignore the user-provided ones? + input_image = moge_image_b1chw_float[:, 0].clone() + self.cache = Cache3D_Buffer( + frame_buffer_max=self.frame_buffer_max, + generator=self.generator, + noise_aug_strength=self.args.noise_aug_strength, + input_image=input_image, # [B, C, H, W] + input_depth=moge_depth_b11hw[:, 0], # [B, 1, H, W] + # input_mask=moge_mask_b11hw[:, 0], # [B, 1, H, W] + input_w2c=moge_initial_w2c_b144[:, 0], # [B, 4, 4] + input_intrinsics=moge_intrinsics_b133[:, 0], # [B, 3, 3] + filter_points_threshold=self.args.filter_points_threshold, + foreground_masking=self.args.foreground_masking, + ) + + seeding_image = input_image_np.transpose(2, 0, 1)[None, ...] / 128.0 - 1.0 + seeding_image = torch.from_numpy(seeding_image).to(device_with_rank(self.device_with_rank)) + + # Return the estimated extrinsics and intrinsics in the same format as the input + estimated_w2c_b44_np = moge_initial_w2c_b144.cpu().numpy()[:, 0, ...] 
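+ # The estimated intrinsics follow the usual layout K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]],
+ # so the focal lengths and principal point returned to the caller are read back from K's
+ # diagonal and last column below.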
+ moge_intrinsics_b133_np = moge_intrinsics_b133.cpu().numpy() + estimated_focal_lengths_b2_np = np.stack([moge_intrinsics_b133_np[:, 0, 0, 0], + moge_intrinsics_b133_np[:, 0, 1, 1]], axis=1) + estimated_principal_point_rel_b2_np = moge_intrinsics_b133_np[:, 0, :2, 2] + + else: + if depths_np is None: + raise NotImplementedError("Seeding from multiple frames requires providing depth values.") + if masks_np is None: + raise NotImplementedError("Seeding from multiple frames requires providing mask values.") + + # RGB: [B, H, W, C] to [B, C, H, W] + image_bchw_float = torch.from_numpy(images_np.transpose(0, 3, 1, 2).astype(np.float32)).to(self.device_with_rank) + # Images are received as 0..1 float32, we convert to -1..1 range. + image_bchw_float = (image_bchw_float * 2.0) - 1.0 + del images_np + + # Depth: [B, H, W] to [B, 1, H, W] + depth_b1hw = torch.from_numpy(depths_np[:, None, ...].astype(np.float32)).to(self.device_with_rank) + # Mask: [B, H, W] to [B, 1, H, W] + mask_b1hw = torch.from_numpy(masks_np[:, None, ...].astype(np.float32)).to(self.device_with_rank) + # World-to-camera: [B, 4, 4] + initial_w2c_b44 = torch.from_numpy(world_to_cameras_np).to(self.device_with_rank) + # Intrinsics: [B, 3, 3] + intrinsics_b33_np = np.zeros((n, 3, 3), dtype=np.float32) + intrinsics_b33_np[:, 0, 0] = focal_lengths_np[:, 0] + intrinsics_b33_np[:, 1, 1] = focal_lengths_np[:, 1] + intrinsics_b33_np[:, 0, 2] = principal_point_rel_np[:, 0] * self.args.width + intrinsics_b33_np[:, 1, 2] = principal_point_rel_np[:, 1] * self.args.height + intrinsics_b33_np[:, 2, 2] = 1.0 + intrinsics_b33 = torch.from_numpy(intrinsics_b33_np).to(self.device_with_rank) + + self.cache = Cache4D( + input_image=image_bchw_float.clone(), # [B, C, H, W] + input_depth=depth_b1hw, # [B, 1, H, W] + input_mask=mask_b1hw, # [B, 1, H, W] + input_w2c=initial_w2c_b44, # [B, 4, 4] + input_intrinsics=intrinsics_b33, # [B, 3, 3] + filter_points_threshold=self.args.filter_points_threshold, + foreground_masking=self.args.foreground_masking, + input_format=["F", "C", "H", "W"], + ) + + # Return the given extrinsics and intrinsics in the same format as the input + seeding_image = image_bchw_float + estimated_w2c_b44_np = world_to_cameras_np + estimated_focal_lengths_b2_np = focal_lengths_np + estimated_principal_point_rel_b2_np = principal_point_rel_np + + # Resize seeding image to match the desired resolution. + if (seeding_image.shape[2] != self.H) or (seeding_image.shape[3] != self.W): + # TODO: would it be better to crop if aspect ratio is off? + seeding_image = transforms_F.resize( + seeding_image, + size=(self.H, self.W), # type: ignore + interpolation=transforms_F.InterpolationMode.BICUBIC, + antialias=True, + ) + # Switch from [B, C, H, W] to [B, C, T, H, W]. + self.seeding_image = seeding_image[:, :, None, ...] + + working_resolutions_b2_np = np.tile([[self.args.width, self.args.height]], (n, 1)) + return ( + estimated_w2c_b44_np, + estimated_focal_lengths_b2_np, + estimated_principal_point_rel_b2_np, + working_resolutions_b2_np + ) + + + @torch.no_grad() + def inference_on_cameras(self, view_cameras_w2cs: np.ndarray, view_camera_intrinsics: np.ndarray, + fps: int | float, + overlap_frames:int = 1, + return_estimated_depths: bool = False, + video_save_quality: int = 5, + save_buffer: bool | None = None) -> dict | None: + + # TODO: this is not safe if multiple inference requests are served in parallel. + # TODO: also, it's not 100% clear whether it is correct to override this request + # after initialization of the pipeline. 
+ self.pipeline.fps = int(fps) + del fps + save_buffer = save_buffer if (save_buffer is not None) else self.args.save_buffer + + video_save_name = self.args.video_save_name + if not video_save_name: + video_save_name = f"video_{time.strftime('%Y-%m-%d_%H-%M-%S')}" + video_save_path = os.path.join(self.args.video_save_folder, f"{video_save_name}.mp4") + os.makedirs(self.args.video_save_folder, exist_ok=True) + + cache_is_multiframe = isinstance(self.cache, Cache4D) + + # Note: the inference server already adjusted intrinsics to match our + # inference resolution (self.W, self.H), so this call is just to make sure + # that all tensors have the right shape, etc. + view_cameras_w2cs, view_camera_intrinsics = self.prepare_camera_for_inference( + view_cameras_w2cs, view_camera_intrinsics, + old_size=(self.H, self.W), new_size=(self.H, self.W) + ) + + n_frames_total = view_cameras_w2cs.shape[1] + num_ar_iterations = (n_frames_total - overlap_frames) // (self.sample_n_frames - overlap_frames) + log.info(f"Generating {n_frames_total} frames will take {num_ar_iterations} auto-regressive iterations") + + # Note: camera trajectory is given by the user, no need to generate it. + log.info(f"Generating frames 0 - {self.sample_n_frames} (out of {n_frames_total} total)...") + rendered_warp_images, rendered_warp_masks = self.cache.render_cache( + view_cameras_w2cs[:, 0:self.sample_n_frames], + view_camera_intrinsics[:, 0:self.sample_n_frames], + start_frame_idx=0, + ) + + all_rendered_warps = [] + all_predicted_depth = [] + if save_buffer: + all_rendered_warps.append(rendered_warp_images.clone().cpu()) + + current_prompt = self.args.prompt + if current_prompt is None and self.args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + return + + + # Generate video + starting_frame = self.seeding_image + if cache_is_multiframe: + starting_frame = starting_frame[0].unsqueeze(0) + + generated_output = self.pipeline.generate( + prompt=current_prompt, + image_path=starting_frame, + negative_prompt=self.args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + return + video, _ = generated_output + + + def depth_for_frame(frame: np.ndarray | torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + last_frame_hwc_0_255 = torch.tensor(frame, device=self.device_with_rank) + pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + pred_depth, pred_mask = _predict_moge_depth_from_tensor( + pred_image_for_depth_chw_0_1, self.moge_model + ) + return pred_depth, pred_mask, pred_image_for_depth_chw_0_1 + + + # We predict depth either if we need it (multi-round generation without depth in the cache), + # or if the user requested it explicitly. + need_depth_of_latest_frame = return_estimated_depths or (num_ar_iterations > 1 and not cache_is_multiframe) + if need_depth_of_latest_frame: + pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video[-1]) + + if return_estimated_depths: + # For easier indexing, we include entries even for the frames for which we don't predict + # depth. Since the results will be transmitted in compressed format, this hopefully + # shouldn't take up any additional bandwidth. + depths_batch_0 = np.full((video.shape[0], 1, self.H, self.W), fill_value=np.nan, + dtype=np.float32) + depths_batch_0[-1, ...] 
= pred_depth.cpu().numpy() + all_predicted_depth.append(depths_batch_0) + del depths_batch_0 + + + # Autoregressive generation (if needed) + for num_iter in range(1, num_ar_iterations): + # Overlap by `overlap_frames` frames + start_frame_idx = num_iter * (self.sample_n_frames - overlap_frames) + end_frame_idx = start_frame_idx + self.sample_n_frames + log.info(f"Generating frames {start_frame_idx} - {end_frame_idx} (out of {n_frames_total} total)...") + + if cache_is_multiframe: + # Nothing much to do, we assume that depth is already provided and + # all frames of the seeding video are already in the cache. + pred_image_for_depth_chw_0_1 = torch.tensor( + video[-1], device=self.device_with_rank + ).permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + else: + self.cache.update_cache( + new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1, # (B,C,H,W) range [-1,1] + new_depth=pred_depth, # (1,1,H,W) + # new_mask=pred_mask, # (1,1,H,W) + new_w2c=view_cameras_w2cs[:, start_frame_idx], + new_intrinsics=view_camera_intrinsics[:, start_frame_idx], + ) + + current_segment_w2cs = view_cameras_w2cs[:, start_frame_idx:end_frame_idx] + current_segment_intrinsics = view_camera_intrinsics[:, start_frame_idx:end_frame_idx] + + cache_start_frame_idx = 0 + if cache_is_multiframe: + # If requesting more frames than are available in the cache, + # freeze (hold) on the last batch of frames. + cache_start_frame_idx = min( + start_frame_idx, + self.cache.input_frame_count() - (end_frame_idx - start_frame_idx) + ) + + rendered_warp_images, rendered_warp_masks = self.cache.render_cache( + current_segment_w2cs, + current_segment_intrinsics, + start_frame_idx=cache_start_frame_idx, + ) + + if save_buffer: + all_rendered_warps.append(rendered_warp_images[:, overlap_frames:].clone().cpu()) + + pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] + generated_output = self.pipeline.generate( + prompt=current_prompt, + image_path=pred_image_for_depth_bcthw_minus1_1, + negative_prompt=self.args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + video_new, _ = generated_output + + video = np.concatenate([video, video_new[overlap_frames:]], axis=0) + + # Prepare depth prediction for the next AR iteration. + need_depth_of_latest_frame = return_estimated_depths or ((num_iter < num_ar_iterations - 1) and not cache_is_multiframe) + if need_depth_of_latest_frame: + # Either we don't have depth (e.g. single-image seeding), or the user requested + # depth to be returned explicitly. + pred_depth, _, pred_image_for_depth_chw_0_1 = depth_for_frame(video_new[-1]) + if return_estimated_depths: + depths_batch_i = np.full((video_new.shape[0] - overlap_frames, 1, self.H, self.W), + fill_value=np.nan, dtype=np.float32) + depths_batch_i[-1, ...] 
= pred_depth.cpu().numpy() + all_predicted_depth.append(depths_batch_i) + del depths_batch_i + + + if is_rank0(): + # Final video processing + final_video_to_save = video + final_width = self.args.width + + if save_buffer and all_rendered_warps: + squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is (T_chunk, n_i, C, H, W) + + if squeezed_warps: + n_max = max(t.shape[1] for t in squeezed_warps) + + padded_t_list = [] + for sq_t in squeezed_warps: + # sq_t shape: (T_chunk, n_i, C, H, W) + current_n_i = sq_t.shape[1] + padding_needed_dim1 = n_max - current_n_i + + pad_spec = (0,0, # W + 0,0, # H + 0,0, # C + 0,padding_needed_dim1, # n_i + 0,0) # T_chunk + padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0) + padded_t_list.append(padded_t) + + full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0) + + T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape + buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4) + buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim) + buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0 + buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8) + buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1)) + + final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2) + final_width = self.args.width * (1 + n_max) + log.info(f"Concatenating video with {n_max} warp buffers. Final video width will be {final_width}") + + else: + log.info("No warp buffers to save.") + + # Save video + save_video( + video=final_video_to_save, + fps=self.pipeline.fps, + H=self.args.height, + W=final_width, + video_save_quality=video_save_quality, + video_save_path=video_save_path, + ) + log.info(f"Saved video to {video_save_path}") + + + if return_estimated_depths: + predicted_depth = np.concatenate(all_predicted_depth, axis=0) + else: + predicted_depth = None + + + # Currently `video` is [n_frames, height, width, channels]. + # Return as [1, n_frames, channels, height, width] for consistency with other codebases. + video = video.transpose(0, 3, 1, 2)[None, ...] + # Depth is returned as [n_frames, channels, height, width]. 
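+ # Sketch of how a caller might consume the result dict assembled below (key names as
+ # returned by this method; `model` stands for an already-seeded Gen3cPersistentModel):
+ #   result = model.inference_on_cameras(view_w2cs, view_intrinsics, fps=24)
+ #   video = result["video"]            # [1, n_frames, C, H, W], uint8
+ #   depth = result["predicted_depth"]  # [n_frames, 1, H, W], or None if not requested
+ #   log.info(f"Saved to {result['video_save_path']}")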
+ + # TODO: handle overlap + rendered_warp_images_no_overlap = rendered_warp_images + video_no_overlap = video + return { + "rendered_warp_images": rendered_warp_images, + "video": video, + "rendered_warp_images_no_overlap": rendered_warp_images_no_overlap, + "video_no_overlap": video_no_overlap, + "predicted_depth": predicted_depth, + "video_save_path": video_save_path, + } + + # -------------------- + + def prepare_camera_for_inference(self, view_cameras: np.ndarray, view_camera_intrinsics: np.ndarray, + old_size: tuple[int, int], new_size: tuple[int, int]): + """Old and new sizes should be given as (height, width).""" + if isinstance(view_cameras, np.ndarray): + view_cameras = torch.from_numpy(view_cameras).float().contiguous() + if view_cameras.ndim == 3: + view_cameras = view_cameras.unsqueeze(dim=0) + + if isinstance(view_camera_intrinsics, np.ndarray): + view_camera_intrinsics = torch.from_numpy(view_camera_intrinsics).float().contiguous() + + view_camera_intrinsics = resize_intrinsics(view_camera_intrinsics, old_size, new_size) + view_camera_intrinsics = view_camera_intrinsics.unsqueeze(dim=0) + assert view_camera_intrinsics.ndim == 4 + + return view_cameras.to(device_with_rank(self.device_with_rank)), \ + view_camera_intrinsics.to(device_with_rank(self.device_with_rank)) + + + def get_cache_input_depths(self) -> torch.Tensor | None: + if self.cache is None: + return None + return self.cache.input_depth + + @property + def W(self) -> int: + return self.args.width + + @property + def H(self) -> int: + return self.args.height + + + def clear_cache(self) -> None: + self.cache = None + self.model_was_seeded = False + + + def cleanup(self) -> None: + if self.args.num_gpus > 1: + rank = get_rank() + log.info(f"Model cleanup: destroying model parallel group on rank={rank}.", + rank0_only=False) + from megatron.core import parallel_state + parallel_state.destroy_model_parallel() + + import torch.distributed as dist + dist.destroy_process_group() + + log.info(f"Destroyed model parallel group on rank={rank}.", rank0_only=False) + else: + log.info("Model cleanup: nothing to do (no parallelism).", rank0_only=False) diff --git a/cosmos_predict1/diffusion/inference/gen3c_pipeline.py b/cosmos_predict1/diffusion/inference/gen3c_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..df892bdf65ca38b82be46ff222ea547702296efc --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_pipeline.py @@ -0,0 +1,256 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
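+# Gen3C pipeline: a thin extension of the Cosmos video2world diffusion pipeline that, in
+# addition to the text prompt and conditioning frame(s), conditions generation on rendered
+# warp images and masks (passed to the model as "condition_state" and
+# "condition_state_mask").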
+ +from typing import Any, Optional + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + generate_world_from_video, + get_video_batch, + load_model_by_config, +) +from cosmos_predict1.diffusion.model.model_gen3c import DiffusionGen3CModel +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline +from cosmos_predict1.utils import log + +class Gen3cPipeline(DiffusionVideo2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + ): + """Initialize diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + """ + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=enable_prompt_upsampler, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + num_input_frames=1, + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionGen3CModel, + ) + + def generate( + self, + prompt: str, + image_path: str, + rendered_warp_images: torch.Tensor, + rendered_warp_masks: torch.Tensor, + negative_prompt: Optional[str] = None, + ) -> Any: + """Generate video from text prompt and optional image. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. 
Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + image_path: Path to conditioning image (or a pre-loaded image tensor) + rendered_warp_images: Rendered warp images + rendered_warp_masks: Rendered warp masks + negative_prompt: Optional text to guide what not to generate + + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + if isinstance(image_path, str): + log.info(f"Run with image path: {image_path}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + log.info(f"Run with prompt: {prompt}") + if not self.disable_guardrail: + log.info(f"Run guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical(f"Input {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt is not safe") + return None + log.info(f"Pass guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + else: + log.info("Not running guardrail") + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + image_or_video_path=image_path, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return video, prompt + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + rendered_warp_images: torch.Tensor, + rendered_warp_masks: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> Any: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
+ + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + condition_latent = self._run_tokenizer_encoding(image_or_video_path) + + if self.offload_network: + self._load_network() + + sample = self._run_model(prompt_embedding, condition_latent, rendered_warp_images, rendered_warp_masks, negative_prompt_embedding) + + if self.offload_network: + self._offload_network() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + + return sample + + def _run_model( + self, + embedding: torch.Tensor, + condition_latent: torch.Tensor, + rendered_warp_images: torch.Tensor, + rendered_warp_masks: torch.Tensor, + negative_prompt_embedding: torch.Tensor | None = None, + ) -> Any: + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + data_batch["condition_state"] = rendered_warp_images + data_batch["condition_state_mask"] = rendered_warp_masks + # Generate video frames + video = generate_world_from_video( + model=self.model, + state_shape=self.model.state_shape, + is_negative_prompt=True, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + condition_latent=condition_latent, + num_input_frames=self.num_input_frames, + ) + + return video diff --git a/cosmos_predict1/diffusion/inference/gen3c_single_image.py b/cosmos_predict1/diffusion/inference/gen3c_single_image.py new file mode 100644 index 0000000000000000000000000000000000000000..856ffffd7d228cb171e3d61fbdd09a6cbfd0eabe --- /dev/null +++ b/cosmos_predict1/diffusion/inference/gen3c_single_image.py @@ -0,0 +1,492 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
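+# Single-image Gen3C inference: MoGe estimates depth (and intrinsics) for the input image,
+# a 3D cache is built from the resulting points, warped guidance frames are rendered along
+# the requested camera trajectory, and the video is generated autoregressively in
+# overlapping chunks.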
+ +import argparse +import os +import cv2 +from moge.model.v1 import MoGeModel +import torch +import numpy as np +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, + check_input_frames, + validate_args, +) +from cosmos_predict1.diffusion.inference.gen3c_pipeline import Gen3cPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video +from cosmos_predict1.diffusion.inference.cache_3d import Cache3D_Buffer +from cosmos_predict1.diffusion.inference.camera_utils import generate_camera_trajectory +import torch.nn.functional as F +torch.enable_grad(False) + +def create_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) # TODO: do we need this? + parser.add_argument( + "--input_image_path", + type=str, + help="Input image path for generating a single video", + ) + parser.add_argument( + "--trajectory", + type=str, + choices=[ + "left", + "right", + "up", + "down", + "zoom_in", + "zoom_out", + "clockwise", + "counterclockwise", + "none", + ], + default="left", + help="Select a trajectory type from the available options (default: original)", + ) + parser.add_argument( + "--camera_rotation", + type=str, + choices=["center_facing", "no_rotation", "trajectory_aligned"], + default="center_facing", + help="Controls camera rotation during movement: center_facing (rotate to look at center), no_rotation (keep orientation), or trajectory_aligned (rotate in the direction of movement)", + ) + parser.add_argument( + "--movement_distance", + type=float, + default=0.3, + help="Distance of the camera from the center of the scene", + ) + parser.add_argument( + "--noise_aug_strength", + type=float, + default=0.0, + help="Strength of noise augmentation on warped frames", + ) + parser.add_argument( + "--save_buffer", + action="store_true", + help="If set, save the warped images (buffer) side by side with the output video.", + ) + parser.add_argument( + "--filter_points_threshold", + type=float, + default=0.05, + help="If set, filter the points continuity of the warped images.", + ) + parser.add_argument( + "--foreground_masking", + action="store_true", + help="If set, use foreground masking for the warped images.", + ) + return parser + +def parse_arguments() -> argparse.Namespace: + parser = create_parser() + return parser.parse_args() + + +def validate_args(args): + assert args.num_video_frames is not None, "num_video_frames must be provided" + assert (args.num_video_frames - 1) % 120 == 0, "num_video_frames must be 121, 241, 361, ... (N*120+1)" + +def _predict_moge_depth(current_image_path: str | np.ndarray, + target_h: int, target_w: int, + device: torch.device, moge_model: MoGeModel): + """Handles MoGe depth prediction for a single image. + + If the image is directly provided as a NumPy array, it should have shape [H, W, C], + where the channels are RGB and the pixel values are in [0..255]. 
+ """ + + if isinstance(current_image_path, str): + input_image_bgr = cv2.imread(current_image_path) + if input_image_bgr is None: + raise FileNotFoundError(f"Input image not found: {current_image_path}") + input_image_rgb = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB) + else: + input_image_rgb = current_image_path + del current_image_path + + depth_pred_h, depth_pred_w = 720, 1280 + + input_image_for_depth_resized = cv2.resize(input_image_rgb, (depth_pred_w, depth_pred_h)) + input_image_for_depth_tensor_chw = torch.tensor(input_image_for_depth_resized / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1) + moge_output_full = moge_model.infer(input_image_for_depth_tensor_chw) + moge_depth_hw_full = moge_output_full["depth"] + moge_intrinsics_33_full_normalized = moge_output_full["intrinsics"] + moge_mask_hw_full = moge_output_full["mask"] + + moge_depth_hw_full = torch.where(moge_mask_hw_full==0, torch.tensor(1000.0, device=moge_depth_hw_full.device), moge_depth_hw_full) + moge_intrinsics_33_full_pixel = moge_intrinsics_33_full_normalized.clone() + moge_intrinsics_33_full_pixel[0, 0] *= depth_pred_w + moge_intrinsics_33_full_pixel[1, 1] *= depth_pred_h + moge_intrinsics_33_full_pixel[0, 2] *= depth_pred_w + moge_intrinsics_33_full_pixel[1, 2] *= depth_pred_h + + # Calculate scaling factor for height + height_scale_factor = target_h / depth_pred_h + width_scale_factor = target_w / depth_pred_w + + # Resize depth map, mask, and image tensor + # Resizing depth: (H, W) -> (1, 1, H, W) for interpolate, then squeeze + moge_depth_hw = F.interpolate( + moge_depth_hw_full.unsqueeze(0).unsqueeze(0), + size=(target_h, target_w), + mode='bilinear', + align_corners=False + ).squeeze(0).squeeze(0) + + # Resizing mask: (H, W) -> (1, 1, H, W) for interpolate, then squeeze + moge_mask_hw = F.interpolate( + moge_mask_hw_full.unsqueeze(0).unsqueeze(0).to(torch.float32), + size=(target_h, target_w), + mode='nearest', # Using nearest neighbor for binary mask + ).squeeze(0).squeeze(0).to(torch.bool) + + # Resizing image tensor: (C, H, W) -> (1, C, H, W) for interpolate, then squeeze + input_image_tensor_chw_target_res = F.interpolate( + input_image_for_depth_tensor_chw.unsqueeze(0), + size=(target_h, target_w), + mode='bilinear', + align_corners=False + ).squeeze(0) + + moge_image_b1chw_float = input_image_tensor_chw_target_res.unsqueeze(0).unsqueeze(1) * 2 - 1 + + moge_intrinsics_33 = moge_intrinsics_33_full_pixel.clone() + # Adjust intrinsics for resized height + moge_intrinsics_33[1, 1] *= height_scale_factor # fy + moge_intrinsics_33[1, 2] *= height_scale_factor # cy + moge_intrinsics_33[0, 0] *= width_scale_factor # fx + moge_intrinsics_33[0, 2] *= width_scale_factor # cx + + moge_depth_b11hw = moge_depth_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0) + moge_depth_b11hw = torch.nan_to_num(moge_depth_b11hw, nan=1e4) + moge_depth_b11hw = torch.clamp(moge_depth_b11hw, min=0, max=1e4) + moge_mask_b11hw = moge_mask_hw.unsqueeze(0).unsqueeze(0).unsqueeze(0) + # Prepare initial intrinsics [B, 1, 3, 3] + moge_intrinsics_b133 = moge_intrinsics_33.unsqueeze(0).unsqueeze(0) + initial_w2c_44 = torch.eye(4, dtype=torch.float32, device=device) + moge_initial_w2c_b144 = initial_w2c_44.unsqueeze(0).unsqueeze(0) + + return ( + moge_image_b1chw_float, + moge_depth_b11hw, + moge_mask_b11hw, + moge_initial_w2c_b144, + moge_intrinsics_b133, + ) + +def _predict_moge_depth_from_tensor( + image_tensor_chw_0_1: torch.Tensor, # Shape (C, H_input, W_input), range [0,1] + moge_model: MoGeModel +): + """Handles MoGe depth 
prediction from an image tensor.""" + moge_output_full = moge_model.infer(image_tensor_chw_0_1) + moge_depth_hw_full = moge_output_full["depth"] # (moge_inf_h, moge_inf_w) + moge_mask_hw_full = moge_output_full["mask"] # (moge_inf_h, moge_inf_w) + + moge_depth_11hw = moge_depth_hw_full.unsqueeze(0).unsqueeze(0) + moge_depth_11hw = torch.nan_to_num(moge_depth_11hw, nan=1e4) + moge_depth_11hw = torch.clamp(moge_depth_11hw, min=0, max=1e4) + moge_mask_11hw = moge_mask_hw_full.unsqueeze(0).unsqueeze(0) + moge_depth_11hw = torch.where(moge_mask_11hw==0, torch.tensor(1000.0, device=moge_depth_11hw.device), moge_depth_11hw) + + return moge_depth_11hw, moge_mask_11hw + +def demo(args): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = Gen3cPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Gen3C-Cosmos-7B", + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=121, + seed=args.seed, + ) + + frame_buffer_max = pipeline.model.frame_buffer_max + generator = torch.Generator(device=device).manual_seed(args.seed) + sample_n_frames = pipeline.model.chunk_size + moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_path}] + + 
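For batch runs, each record read from `--batch_input_path` must carry at least the two keys consumed by the loop below, `"prompt"` and `"visual_input"`. A hedged sketch of writing such a JSONL file; the file name, prompts, and image paths are placeholders, and the exact schema accepted by `read_prompts_from_file` should be confirmed against its implementation:

```python
import json

# Hypothetical batch entries; "prompt" and "visual_input" are the keys read below.
batch = [
    {"prompt": "A robot arm picks up a red cube.", "visual_input": "inputs/frame_000.png"},
    {"prompt": "A drone flies over a warehouse.", "visual_input": "inputs/frame_001.png"},
]

with open("batch_inputs.jsonl", "w") as f:
    for item in batch:
        f.write(json.dumps(item) + "\n")
```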
os.makedirs(os.path.dirname(args.video_save_folder), exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_image_path = input_dict.get("visual_input", None) + if current_image_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Check input frames + if not check_input_frames(current_image_path, 1): + print(f"Input image {current_image_path} is not valid, skipping.") + continue + + # load image, predict depth and initialize 3D cache + ( + moge_image_b1chw_float, + moge_depth_b11hw, + moge_mask_b11hw, + moge_initial_w2c_b144, + moge_intrinsics_b133, + ) = _predict_moge_depth( + current_image_path, args.height, args.width, device, moge_model + ) + + cache = Cache3D_Buffer( + frame_buffer_max=frame_buffer_max, + generator=generator, + noise_aug_strength=args.noise_aug_strength, + input_image=moge_image_b1chw_float[:, 0].clone(), # [B, C, H, W] + input_depth=moge_depth_b11hw[:, 0], # [B, 1, H, W] + # input_mask=moge_mask_b11hw[:, 0], # [B, 1, H, W] + input_w2c=moge_initial_w2c_b144[:, 0], # [B, 4, 4] + input_intrinsics=moge_intrinsics_b133[:, 0],# [B, 3, 3] + filter_points_threshold=args.filter_points_threshold, + foreground_masking=args.foreground_masking, + ) + + initial_cam_w2c_for_traj = moge_initial_w2c_b144[0, 0] + initial_cam_intrinsics_for_traj = moge_intrinsics_b133[0, 0] + + # Generate camera trajectory using the new utility function + try: + generated_w2cs, generated_intrinsics = generate_camera_trajectory( + trajectory_type=args.trajectory, + initial_w2c=initial_cam_w2c_for_traj, + initial_intrinsics=initial_cam_intrinsics_for_traj, + num_frames=args.num_video_frames, + movement_distance=args.movement_distance, + camera_rotation=args.camera_rotation, + center_depth=1.0, + device=device.type, + ) + except (ValueError, NotImplementedError) as e: + log.critical(f"Failed to generate trajectory: {e}") + continue + + log.info(f"Generating 0 - {sample_n_frames} frames") + rendered_warp_images, rendered_warp_masks = cache.render_cache( + generated_w2cs[:, 0:sample_n_frames], + generated_intrinsics[:, 0:sample_n_frames], + ) + + all_rendered_warps = [] + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images.clone().cpu()) + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=current_image_path, + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + num_ar_iterations = (generated_w2cs.shape[1] - 1) // (sample_n_frames - 1) + for num_iter in range(1, num_ar_iterations): + start_frame_idx = num_iter * (sample_n_frames - 1) # Overlap by 1 frame + end_frame_idx = start_frame_idx + sample_n_frames + + log.info(f"Generating {start_frame_idx} - {end_frame_idx} frames") + + last_frame_hwc_0_255 = torch.tensor(video[-1], device=device) + pred_image_for_depth_chw_0_1 = last_frame_hwc_0_255.permute(2, 0, 1) / 255.0 # (C,H,W), range [0,1] + + pred_depth, pred_mask = _predict_moge_depth_from_tensor( + pred_image_for_depth_chw_0_1, moge_model + ) + + cache.update_cache( + new_image=pred_image_for_depth_chw_0_1.unsqueeze(0) * 2 - 1, # (B,C,H,W) range [-1,1] + new_depth=pred_depth, # 
(1,1,H,W) + # new_mask=pred_mask, # (1,1,H,W) + new_w2c=generated_w2cs[:, start_frame_idx], + new_intrinsics=generated_intrinsics[:, start_frame_idx], + ) + current_segment_w2cs = generated_w2cs[:, start_frame_idx:end_frame_idx] + current_segment_intrinsics = generated_intrinsics[:, start_frame_idx:end_frame_idx] + rendered_warp_images, rendered_warp_masks = cache.render_cache( + current_segment_w2cs, + current_segment_intrinsics, + ) + + if args.save_buffer: + all_rendered_warps.append(rendered_warp_images[:, 1:].clone().cpu()) + + + pred_image_for_depth_bcthw_minus1_1 = pred_image_for_depth_chw_0_1.unsqueeze(0).unsqueeze(2) * 2 - 1 # (B,C,T,H,W), range [-1,1] + generated_output = pipeline.generate( + prompt=current_prompt, + image_path=pred_image_for_depth_bcthw_minus1_1, + negative_prompt=args.negative_prompt, + rendered_warp_images=rendered_warp_images, + rendered_warp_masks=rendered_warp_masks, + ) + video_new, prompt = generated_output + video = np.concatenate([video, video_new[1:]], axis=0) + + # Final video processing + final_video_to_save = video + final_width = args.width + + if args.save_buffer and all_rendered_warps: + squeezed_warps = [t.squeeze(0) for t in all_rendered_warps] # Each is (T_chunk, n_i, C, H, W) + + if squeezed_warps: + n_max = max(t.shape[1] for t in squeezed_warps) + + padded_t_list = [] + for sq_t in squeezed_warps: + # sq_t shape: (T_chunk, n_i, C, H, W) + current_n_i = sq_t.shape[1] + padding_needed_dim1 = n_max - current_n_i + + pad_spec = (0,0, # W + 0,0, # H + 0,0, # C + 0,padding_needed_dim1, # n_i + 0,0) # T_chunk + padded_t = F.pad(sq_t, pad_spec, mode='constant', value=-1.0) + padded_t_list.append(padded_t) + + full_rendered_warp_tensor = torch.cat(padded_t_list, dim=0) + + T_total, _, C_dim, H_dim, W_dim = full_rendered_warp_tensor.shape + buffer_video_TCHnW = full_rendered_warp_tensor.permute(0, 2, 3, 1, 4) + buffer_video_TCHWstacked = buffer_video_TCHnW.contiguous().view(T_total, C_dim, H_dim, n_max * W_dim) + buffer_video_TCHWstacked = (buffer_video_TCHWstacked * 0.5 + 0.5) * 255.0 + buffer_numpy_TCHWstacked = buffer_video_TCHWstacked.cpu().numpy().astype(np.uint8) + buffer_numpy_THWC = np.transpose(buffer_numpy_TCHWstacked, (0, 2, 3, 1)) + + final_video_to_save = np.concatenate([buffer_numpy_THWC, final_video_to_save], axis=2) + final_width = args.width * (1 + n_max) + log.info(f"Concatenating video with {n_max} warp buffers. 
Final video width will be {final_width}") + else: + log.info("No warp buffers to save.") + + + video_save_path = os.path.join( + args.video_save_folder, + f"{i if args.batch_input_path else args.video_save_name}.mp4" + ) + + os.makedirs(os.path.dirname(video_save_path), exist_ok=True) + + # Save video + save_video( + video=final_video_to_save, + fps=args.fps, + H=args.height, + W=final_width, + video_save_quality=5, + video_save_path=video_save_path, + ) + log.info(f"Saved video to {video_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + if args.prompt is None: + args.prompt = "" + args.disable_guardrail = True + args.disable_prompt_upsampler = True + demo(args) diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..61e9563a08150555a07e140d71894a047be07d09 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/inference_utils.py @@ -0,0 +1,936 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +from contextlib import contextmanager +from typing import List, NamedTuple, Optional, Tuple + +import einops +import imageio +import numpy as np +import omegaconf.errors +import torch +import torchvision.transforms.functional as transforms_F +from omegaconf import OmegaConf + +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel +from cosmos_predict1.diffusion.model.model_v2w_multiview import DiffusionMultiviewV2WModel +from cosmos_predict1.diffusion.model.model_world_interpolator import DiffusionWorldInterpolatorWModel +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.io import load_from_fileobj + +TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) +if TORCH_VERSION >= (1, 11): + from torch.ao import quantization + from torch.ao.quantization import FakeQuantizeBase, ObserverBase +elif ( + TORCH_VERSION >= (1, 8) + and hasattr(torch.quantization, "FakeQuantizeBase") + and hasattr(torch.quantization, "ObserverBase") +): + from torch import quantization + from torch.quantization import FakeQuantizeBase, ObserverBase + +DEFAULT_AUGMENT_SIGMA = 0.001 + + +def add_common_arguments(parser): + """Add common command line arguments for text2world and video2world generation. 
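Before the argument-by-argument description, a small usage sketch: attach the shared flags to a fresh parser and inspect a few of the defaults documented below (parsing an empty argument list simply yields the defaults; this assumes the package is importable):

```python
import argparse

from cosmos_predict1.diffusion.inference.inference_utils import add_common_arguments

parser = argparse.ArgumentParser(description="example")
add_common_arguments(parser)
args = parser.parse_args([])  # empty command line -> defaults only

print(args.height, args.width, args.fps)  # 704 1280 24
print(args.num_steps, args.guidance)      # 35 7.0
```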
+ + Args: + parser (ArgumentParser): Argument parser to add arguments to + + The arguments include: + - checkpoint_dir: Base directory containing model weights + - tokenizer_dir: Directory containing tokenizer weights + - video_save_name: Output video filename for single video generation + - video_save_folder: Output directory for batch video generation + - prompt: Text prompt for single video generation + - batch_input_path: Path to JSONL file with input prompts for batch video generation + - negative_prompt: Text prompt describing undesired attributes + - num_steps: Number of diffusion sampling steps + - guidance: Classifier-free guidance scale + - num_video_frames: Number of frames to generate + - height/width: Output video dimensions + - fps: Output video frame rate + - seed: Random seed for reproducibility + - Various model offloading flags + """ + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--tokenizer_dir", + type=str, + default="Cosmos-Tokenize1-CV8x8x8-720p", + help="Tokenizer weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--video_save_name", + type=str, + default="output", + help="Output filename for generating a single video", + ) + parser.add_argument( + "--video_save_folder", + type=str, + default="outputs/", + help="Output folder for generating a batch of videos", + ) + parser.add_argument( + "--prompt", + type=str, + help="Text prompt for generating a single video", + ) + parser.add_argument( + "--batch_input_path", + type=str, + help="Path to a JSONL file of input prompts for generating a batch of videos", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special " + "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and " + "flickering. 
Overall, the video is of poor quality.", + help="Negative prompt for the video", + ) + parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps") + parser.add_argument("--guidance", type=float, default=7, help="Guidance scale value") + parser.add_argument( + "--num_video_frames", + type=int, + default=121, + # choices=[8 * n + 1 for n in range(16)] + [10, 117], + help="Number of video frames to sample", + ) + parser.add_argument("--height", type=int, default=704, help="Height of video to sample") + parser.add_argument("--width", type=int, default=1280, help="Width of video to sample") + parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.") + parser.add_argument( + "--disable_prompt_upsampler", + action="store_true", + help="Disable prompt upsampling", + ) + parser.add_argument( + "--offload_diffusion_transformer", + action="store_true", + help="Offload DiT after inference", + ) + parser.add_argument( + "--offload_tokenizer", + action="store_true", + help="Offload tokenizer after inference", + ) + parser.add_argument( + "--offload_text_encoder_model", + action="store_true", + help="Offload text encoder model after inference", + ) + parser.add_argument( + "--offload_prompt_upsampler", + action="store_true", + help="Offload prompt upsampler after inference", + ) + parser.add_argument( + "--offload_guardrail_models", + action="store_true", + help="Offload guardrail models after inference", + ) + parser.add_argument( + "--disable_guardrail", + action="store_true", + help="Disable guardrail models", + ) + + +# Function to fully remove an argument +def remove_argument(parser, arg_name): + # Get a list of actions to remove + actions_to_remove = [action for action in parser._actions if action.dest == arg_name] + + for action in actions_to_remove: + # Remove action from parser._actions + parser._actions.remove(action) + + # Remove option strings + for option_string in action.option_strings: + parser._option_string_actions.pop(option_string, None) + + +def validate_args(args: argparse.Namespace, inference_type: str) -> None: + """Validate command line arguments for text2world and video2world generation.""" + assert inference_type in [ + "text2world", + "video2world", + "world_interpolator", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + # Validate prompt/image/video args for single or batch generation + if inference_type == "text2world" or (inference_type == "video2world" and args.disable_prompt_upsampler): + assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided." + if (inference_type == "video2world" or inference_type == "world_interpolator") and not args.batch_input_path: + assert ( + args.input_image_or_video_path + ), "--input_image_or_video_path must be provided for single video generation." + + +class _IncompatibleKeys( + NamedTuple( + "IncompatibleKeys", + [ + ("missing_keys", List[str]), + ("unexpected_keys", List[str]), + ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), + ], + ) +): + pass + + +def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: + """Load a model checkpoint with non-strict matching, handling shape mismatches. 
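Stripped of the quantization and TransformerEngine special cases handled below, the core pattern is: drop checkpoint tensors whose shapes disagree with the model, then call `load_state_dict(strict=False)` and report what was skipped. A minimal sketch of that pattern on toy modules:

```python
import torch

def filtered_non_strict_load(model: torch.nn.Module, ckpt_state: dict):
    # Remove checkpoint tensors whose shape differs from the model's,
    # then let strict=False surface missing/unexpected keys.
    model_state = model.state_dict()
    dropped = []
    for name in list(ckpt_state.keys()):
        if name in model_state and tuple(ckpt_state[name].shape) != tuple(model_state[name].shape):
            dropped.append((name, tuple(ckpt_state[name].shape), tuple(model_state[name].shape)))
            ckpt_state.pop(name)
    result = model.load_state_dict(ckpt_state, strict=False)
    return result, dropped

src, dst = torch.nn.Linear(4, 4), torch.nn.Linear(4, 8)
_, dropped = filtered_non_strict_load(dst, src.state_dict())
print(dropped)  # [('weight', (4, 4), (8, 4)), ('bias', (4,), (8, ))]
```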
+ + Args: + model (torch.nn.Module): Model to load weights into + checkpoint_state_dict (dict): State dict from checkpoint + + Returns: + _IncompatibleKeys: Named tuple containing: + - missing_keys: Keys present in model but missing from checkpoint + - unexpected_keys: Keys present in checkpoint but not in model + - incorrect_shapes: Keys with mismatched tensor shapes + + The function handles special cases like: + - Uninitialized parameters + - Quantization observers + - TransformerEngine FP8 states + """ + # workaround https://github.com/pytorch/pytorch/issues/24139 + model_state_dict = model.state_dict() + incorrect_shapes = [] + for k in list(checkpoint_state_dict.keys()): + if k in model_state_dict: + if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 + log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") + continue + model_param = model_state_dict[k] + # Allow mismatch for uninitialized parameters + if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): + continue + if not isinstance(model_param, torch.Tensor): + raise ValueError( + f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." + ) + + shape_model = tuple(model_param.shape) + shape_checkpoint = tuple(checkpoint_state_dict[k].shape) + if shape_model != shape_checkpoint: + has_observer_base_classes = ( + TORCH_VERSION >= (1, 8) + and hasattr(quantization, "ObserverBase") + and hasattr(quantization, "FakeQuantizeBase") + ) + if has_observer_base_classes: + # Handle the special case of quantization per channel observers, + # where buffer shape mismatches are expected. + def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: + # foo.bar.param_or_buffer_name -> [foo, bar] + key_parts = key.split(".")[:-1] + cur_module = model + for key_part in key_parts: + cur_module = getattr(cur_module, key_part) + return cur_module + + cls_to_skip = ( + ObserverBase, + FakeQuantizeBase, + ) + target_module = _get_module_for_key(model, k) + if isinstance(target_module, cls_to_skip): + # Do not remove modules with expected shape mismatches + # them from the state_dict loading. They have special logic + # in _load_from_state_dict to handle the mismatches. 
+ continue + + incorrect_shapes.append((k, shape_checkpoint, shape_model)) + checkpoint_state_dict.pop(k) + incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) + # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling + missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] + unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] + return _IncompatibleKeys( + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + incorrect_shapes=incorrect_shapes, + ) + + +@contextmanager +def skip_init_linear(): + # skip init of nn.Linear + orig_reset_parameters = torch.nn.Linear.reset_parameters + torch.nn.Linear.reset_parameters = lambda x: x + xavier_uniform_ = torch.nn.init.xavier_uniform_ + torch.nn.init.xavier_uniform_ = lambda x: x + yield + torch.nn.Linear.reset_parameters = orig_reset_parameters + torch.nn.init.xavier_uniform_ = xavier_uniform_ + + +def load_model_by_config( + config_job_name, + config_file="projects/cosmos_video/config/config.py", + model_class=DiffusionT2WModel, +): + config_module = get_config_module(config_file) + config = importlib.import_module(config_module).make_config() + + config = override(config, ["--", f"experiment={config_job_name}"]) + + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + # Initialize model + with skip_init_linear(): + model = model_class(config.model) + return model + + +def load_network_model(model: DiffusionT2WModel, ckpt_path: str): + with skip_init_linear(): + model.set_up_model() + try: + net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + except Exception: + # Posttrained models can be loaded with weights_only=False + net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if "model" in net_state_dict: + model_state_dict = net_state_dict["model"] + if "ema" in net_state_dict and model.config.peft_control and model.config.peft_control.enabled: + ema_state_dict = net_state_dict["ema"] + # Convert ema state_dict to model state_dict by replacing "-" with "." + ema_state_dict = {k.replace("-", "."): v for k, v in ema_state_dict.items()} + model_state_dict.update(ema_state_dict) + net_state_dict = model_state_dict + else: + net_state_dict = model_state_dict + + log.debug(non_strict_load_model(model.model, net_state_dict)) + model.cuda() + + +def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str): + with skip_init_linear(): + model.set_up_tokenizer(tokenizer_dir) + model.cuda() + + +def prepare_data_batch( + height: int, + width: int, + num_frames: int, + fps: int, + prompt_embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, +): + """Prepare input batch tensors for video generation. 
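A hypothetical call, assuming a CUDA device (the helper moves every tensor to the GPU) and a T5-style prompt embedding of shape (1, 512, 1024); the embedding width is an assumption for illustration only:

```python
import torch

from cosmos_predict1.diffusion.inference.inference_utils import prepare_data_batch

# Placeholder embeddings standing in for real T5 encoder outputs.
prompt_emb = torch.randn(1, 512, 1024)
neg_emb = torch.randn(1, 512, 1024)

batch = prepare_data_batch(
    height=704, width=1280, num_frames=121, fps=24,
    prompt_embedding=prompt_emb,
    negative_prompt_embedding=neg_emb,
)
print(batch["video"].shape)               # torch.Size([1, 3, 121, 704, 1280])
print(batch["t5_text_embeddings"].dtype)  # torch.bfloat16
```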
+ + Args: + height (int): Height of video frames + width (int): Width of video frames + num_frames (int): Number of frames to generate + fps (int): Frames per second + prompt_embedding (torch.Tensor): Encoded text prompt embeddings + negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings + + Returns: + dict: Batch dictionary containing: + - video: Zero tensor of target video shape + - t5_text_mask: Attention mask for text embeddings + - image_size: Target frame dimensions + - fps: Target frame rate + - num_frames: Number of frames + - padding_mask: Frame padding mask + - t5_text_embeddings: Prompt embeddings + - neg_t5_text_embeddings: Negative prompt embeddings (if provided) + - neg_t5_text_mask: Mask for negative embeddings (if provided) + """ + # Create base data batch + data_batch = { + "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(), + "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), + "image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(), + "fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(), + "num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(), + "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(), + } + + # Handle text embeddings + + t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda() + data_batch["t5_text_embeddings"] = t5_embed + + if negative_prompt_embedding is not None: + neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda() + data_batch["neg_t5_text_embeddings"] = neg_t5_embed + data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda() + + return data_batch + + +def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames): + """Prepare complete input batch for video generation including latent dimensions. 
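The latent `state_shape` is essentially the pixel resolution divided by the tokenizer's spatial compression, plus a temporally compressed frame count. A worked sketch under the assumption of an 8x8x8 causal tokenizer (suggested by the default `Cosmos-Tokenize1-CV8x8x8-720p` name, but not guaranteed here) with an assumed 16 latent channels:

```python
# Assumed compression factors and channel count, for illustration only.
spatial_compression = 8
temporal_compression = 8
latent_channels = 16

height, width, num_video_frames = 704, 1280, 121

latent_t = 1 + (num_video_frames - 1) // temporal_compression  # causal: first frame kept as-is
state_shape = [
    latent_channels,
    latent_t,
    height // spatial_compression,
    width // spatial_compression,
]
print(state_shape)  # [16, 16, 88, 160]
```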
+ + Args: + model: Diffusion model instance + prompt_embedding (torch.Tensor): Text prompt embeddings + negative_prompt_embedding (torch.Tensor): Negative prompt embeddings + height (int): Output video height + width (int): Output video width + fps (int): Output video frame rate + num_video_frames (int): Number of frames to generate + + Returns: + tuple: + - data_batch (dict): Complete model input batch + - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression + """ + raw_video_batch = prepare_data_batch( + height=height, + width=width, + num_frames=num_video_frames, + fps=fps, + prompt_embedding=prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + ) + try: + condition_location = model.config.conditioner.video_cond_bool.condition_location + except omegaconf.errors.ConfigAttributeError: + condition_location = None + + # Use condition_location in your logic + if condition_location == "first_and_last_1": + state_shape = [ + model.tokenizer.channel, + model.tokenizer.get_latent_num_frames(num_video_frames - 1) + 1, # +1 for the last frame + height // model.tokenizer.spatial_compression_factor, + width // model.tokenizer.spatial_compression_factor, + ] + else: + state_shape = [ + model.tokenizer.channel, + model.tokenizer.get_latent_num_frames(num_video_frames), + height // model.tokenizer.spatial_compression_factor, + width // model.tokenizer.spatial_compression_factor, + ] + + return raw_video_batch, state_shape + + +def get_video_batch_for_multiview_model( + model, prompt_embedding, height, width, fps, num_video_frames, frame_repeat_negative_condition +): + """Prepare complete input batch for video generation including latent dimensions. + + Args: + model: Diffusion model instance + prompt_embedding (torch.Tensor): Text prompt embeddings + height (int): Output video height + width (int): Output video width + fps (int): Output video frame rate + num_video_frames (int): Number of frames to generate + frame_repeat_negative_condition (int): Number of frames to generate + + Returns: + tuple: + - data_batch (dict): Complete model input batch + - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression + """ + n_views = len(prompt_embedding) + prompt_embedding = einops.rearrange(torch.cat(prompt_embedding), "n t d -> (n t) d").unsqueeze(0) + raw_video_batch = prepare_data_batch( + height=height, + width=width, + num_frames=num_video_frames, + fps=fps, + prompt_embedding=prompt_embedding, + ) + if frame_repeat_negative_condition != -1: + frame_repeat = torch.zeros(n_views) + frame_repeat[-1] = frame_repeat_negative_condition + frame_repeat[-2] = frame_repeat_negative_condition + raw_video_batch["frame_repeat"] = frame_repeat.unsqueeze(0).to(dtype=torch.bfloat16).cuda() + state_shape = [ + model.tokenizer.channel, + model.tokenizer.get_latent_num_frames(int(num_video_frames / n_views)) * n_views, + height // model.tokenizer.spatial_compression_factor, + width // model.tokenizer.spatial_compression_factor, + ] + return raw_video_batch, state_shape + + +def generate_world_from_text( + model: DiffusionT2WModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + guidance: float, + num_steps: int, + seed: int, +): + """Generate video from text prompt using diffusion model. 
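The `guidance` value scales classifier-free guidance, which is conventionally the following blend of conditional and unconditional predictions; the model may implement this with its own refinements, so treat the sketch as the textbook form rather than this codebase's exact rule:

```python
import torch

def cfg_combine(pred_uncond: torch.Tensor, pred_cond: torch.Tensor, guidance: float) -> torch.Tensor:
    # Push the conditional prediction away from the unconditional one by `guidance`.
    return pred_uncond + guidance * (pred_cond - pred_uncond)

out = cfg_combine(torch.zeros(2, 4), torch.ones(2, 4), guidance=7.0)
print(out[0, 0].item())  # 7.0
```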
+ + Args: + model (DiffusionT2WModel): Text-to-video diffusion model + state_shape (list[int]): Latent state dimensions [C,T,H,W] + is_negative_prompt (bool): Whether negative prompt is provided + data_batch (dict): Model input batch with embeddings + guidance (float): Classifier-free guidance scale + num_steps (int): Number of diffusion sampling steps + seed (int): Random seed for reproducibility + + Returns: + np.ndarray: Generated video frames [T,H,W,C], range [0,255] + + The function: + 1. Initializes random latent with maximum noise + 2. Performs guided diffusion sampling + 3. Decodes latents to pixel space + """ + + # Generate video + sample = model.generate_samples_from_batch( + data_batch, + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed, + ) + + return sample + + +def generate_world_from_video( + model: DiffusionV2WModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + guidance: float, + num_steps: int, + seed: int, + condition_latent: torch.Tensor, + num_input_frames: int, +) -> Tuple[np.array, list, list]: + """Generate video using a conditioning video/image input. + + Args: + model (DiffusionV2WModel): The diffusion model instance + state_shape (list[int]): Shape of the latent state [C,T,H,W] + is_negative_prompt (bool): Whether negative prompt is provided + data_batch (dict): Batch containing model inputs including text embeddings + guidance (float): Classifier-free guidance scale for sampling + num_steps (int): Number of diffusion sampling steps + seed (int): Random seed for generation + condition_latent (torch.Tensor): Latent tensor from conditioning video/image file + num_input_frames (int): Number of input frames + + Returns: + np.array: Generated video frames in shape [T,H,W,C], range [0,255] + """ + assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported" + augment_sigma = DEFAULT_AUGMENT_SIGMA + + if condition_latent.shape[2] < state_shape[1]: + # Padding condition latent to state shape + b, c, t, h, w = condition_latent.shape + condition_latent = torch.cat( + [ + condition_latent, + condition_latent.new_zeros(b, c, state_shape[1] - t, h, w), + ], + dim=2, + ).contiguous() + num_of_latent_condition = compute_num_latent_frames(model, num_input_frames) + + sample = model.generate_samples_from_batch( + data_batch, + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed, + condition_latent=condition_latent, + num_condition_t=num_of_latent_condition, + condition_augment_sigma=augment_sigma, + ) + return sample + + +def read_video_or_image_into_frames_BCTHW( + input_path: str, + input_path_format: str = "mp4", + H: int = None, + W: int = None, + normalize: bool = True, + max_frames: int = -1, + also_return_fps: bool = False, +) -> torch.Tensor: + """Read video or image file and convert to tensor format. 
+ + Args: + input_path (str): Path to input video/image file + input_path_format (str): Format of input file (default: "mp4") + H (int, optional): Height to resize frames to + W (int, optional): Width to resize frames to + normalize (bool): Whether to normalize pixel values to [-1,1] (default: True) + max_frames (int): Maximum number of frames to read (-1 for all frames) + also_return_fps (bool): Whether to return fps along with frames + + Returns: + torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested + """ + log.debug(f"Reading video from {input_path}") + + loaded_data = load_from_fileobj(input_path, format=input_path_format) + frames, meta_data = loaded_data + if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"): + frames = np.array(frames[0]) # HWC, [0,255] + if frames.shape[-1] > 3: # RGBA, set the transparent to white + # Separate the RGB and Alpha channels + rgb_channels = frames[..., :3] + alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1] + + # Create a white background + white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB + + # Blend the RGB channels with the white background based on the alpha channel + frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype( + np.uint8 + ) + frames = [frames] + fps = 0 + else: + fps = int(meta_data.get("fps")) + if max_frames != -1: + frames = frames[:max_frames] + input_tensor = np.stack(frames, axis=0) + input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w") + if normalize: + input_tensor = input_tensor / 128.0 - 1.0 + input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW + log.debug(f"Raw data shape: {input_tensor.shape}") + if H is not None and W is not None: + input_tensor = transforms_F.resize( + input_tensor, + size=(H, W), # type: ignore + interpolation=transforms_F.InterpolationMode.BICUBIC, + antialias=True, + ) + input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1) + if normalize: + input_tensor = input_tensor.to("cuda") + log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}") + if also_return_fps: + return input_tensor, fps + return input_tensor + + +def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int: + """This function computes the number of latent frames given the number of input frames. 
+ Args: + model (DiffusionV2WModel): video generation model + num_input_frames (int): number of input frames + downsample_factor (int): downsample factor for temporal reduce + Returns: + int: number of latent frames + """ + # First find how many vae chunks are contained with in num_input_frames + num_latent_frames = ( + num_input_frames + // model.tokenizer.video_vae.pixel_chunk_duration + * model.tokenizer.video_vae.latent_chunk_duration + ) + # Then handle the remainder + if num_input_frames % model.tokenizer.video_vae.latent_chunk_duration == 1: + num_latent_frames += 1 + elif num_input_frames % model.tokenizer.video_vae.latent_chunk_duration > 1: + assert ( + num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 + ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}" + num_latent_frames += ( + 1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor + ) + + return num_latent_frames + + +def create_condition_latent_from_input_frames( + model: DiffusionV2WModel, + input_frames: torch.Tensor, + num_frames_condition: int = 25, +): + """Create condition latent for video generation from input frames. + + Takes the last num_frames_condition frames from input as conditioning. + + Args: + model (DiffusionV2WModel): Video generation model + input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1] + num_frames_condition (int): Number of frames to use for conditioning + + Returns: + tuple: (condition_latent, encode_input_frames) where: + - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] + - encode_input_frames (torch.Tensor): Padded input frames used for encoding + """ + B, C, T, H, W = input_frames.shape + num_frames_encode = ( + model.tokenizer.pixel_chunk_duration + ) # (model.state_shape[1] - 1) / model.vae.pixel_chunk_duration + 1 + log.debug( + f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}" + ) + + log.debug( + f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}" + ) + + assert ( + input_frames.shape[2] >= num_frames_condition + ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}" + assert ( + num_frames_encode >= num_frames_condition + ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}" + + # Put the conditioal frames to the begining of the video, and pad the end with zero + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + condition_frames_first = input_frames[:, :, :num_frames_condition] + condition_frames_last = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2) + else: + condition_frames = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) + + log.debug( + f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end" + ) + if 
hasattr(model, "n_views"): + encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views) + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode]) # BCTHW + latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:]) + latent = torch.cat([latent1, latent2], dim=2) # BCTHW + else: + latent = model.encode(encode_input_frames) + return latent, encode_input_frames + + +def compute_num_frames_condition(model: DiffusionV2WModel, num_of_latent_overlap: int, downsample_factor=8) -> int: + """This function computes the number of condition pixel frames given the number of latent frames to overlap. + Args: + model (ExtendDiffusionModel): video generation model + num_of_latent_overlap (int): number of latent frames to overlap + downsample_factor (int): downsample factor for temporal reduce + Returns: + int: number of condition frames in output space + """ + if getattr(model.tokenizer.video_vae, "is_casual", True): + # For casual model + num_frames_condition = ( + num_of_latent_overlap + // model.tokenizer.video_vae.latent_chunk_duration + * model.tokenizer.video_vae.pixel_chunk_duration + ) + if num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration == 1: + num_frames_condition += 1 + elif num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration > 1: + num_frames_condition += ( + 1 + (num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration - 1) * downsample_factor + ) + else: + num_frames_condition = num_of_latent_overlap * downsample_factor + + return num_frames_condition + + +def get_condition_latent( + model: DiffusionV2WModel, + input_image_or_video_path: str, + num_input_frames: int = 1, + state_shape: list[int] = None, + frame_index: int = 0, + frame_stride: int = 1, +): + """Get condition latent from input image/video file. 
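The conditioning recipe in `create_condition_latent_from_input_frames` above is: keep the last `num_frames_condition` frames, zero-pad up to the encoder's chunk length, and encode. A tensor-level sketch of that padding step (the 57-frame chunk length and 9 conditioning frames are assumptions for illustration):

```python
import torch

num_frames_encode = 57     # assumed encoder pixel chunk length
num_frames_condition = 9   # frames to condition on

input_frames = torch.randn(1, 3, 24, 64, 64)  # (B, C, T, H, W), T >= num_frames_condition
condition = input_frames[:, :, -num_frames_condition:]
padding = condition.new_zeros(1, 3, num_frames_encode - num_frames_condition, 64, 64)
encode_input = torch.cat([condition, padding], dim=2)
print(encode_input.shape)  # torch.Size([1, 3, 57, 64, 64])
```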
+ + Args: + model (DiffusionV2WModel): Video generation model + input_image_or_video_path (str): Path to conditioning image/video + num_input_frames (int): Number of input frames for video2world prediction + + Returns: + tuple: (condition_latent, input_frames) where: + - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] + - input_frames (torch.Tensor): Input frames tensor [B,C,T,H,W] + """ + if state_shape is None: + state_shape = model.state_shape + assert num_input_frames > 0, "num_input_frames must be greater than 0" + + H, W = ( + state_shape[-2] * model.tokenizer.spatial_compression_factor, + state_shape[-1] * model.tokenizer.spatial_compression_factor, + ) + if type(input_image_or_video_path) == str: + input_path_format = input_image_or_video_path.split(".")[-1] + input_frames = read_video_or_image_into_frames_BCTHW( + input_image_or_video_path, + input_path_format=input_path_format, + H=H, + W=W, + ) + else: + input_frames = input_image_or_video_path + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + start_frame = frame_index * frame_stride + end_frame = (frame_index + 1) * frame_stride + curr_input_frames = torch.cat( + [input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2 + ).contiguous() # BCTHW + num_of_latent_condition = 1 + num_frames_condition = compute_num_frames_condition( + model, num_of_latent_condition, downsample_factor=model.tokenizer.temporal_compression_factor + ) + + condition_latent, _ = create_condition_latent_from_input_frames(model, curr_input_frames, num_frames_condition) + condition_latent = condition_latent.to(torch.bfloat16) + return condition_latent + + condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames) + condition_latent = condition_latent.to(torch.bfloat16) + + return condition_latent + + +def get_condition_latent_multiview( + model: DiffusionMultiviewV2WModel, + input_image_or_video_path: str, + num_input_frames: int = 1, + state_shape: list[int] = None, +): + """Get condition latent from input image/video file. This is the function for the multi-view model where each view has one latent condition frame. 
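The multi-view helpers above and below hinge on one einops pattern: views are stored interleaved along the time axis and folded into the batch dimension for per-view processing. A small round-trip sketch (the view count and per-view frame count are assumptions):

```python
import einops
import torch

V, T = 6, 57                        # assumed number of views and frames per view
x = torch.randn(1, 3, V * T, 8, 8)  # (B, C, V*T, H, W)

per_view = einops.rearrange(x, "B C (V T) H W -> (B V) C T H W", V=V)
print(per_view.shape)               # torch.Size([6, 3, 57, 8, 8])

back = einops.rearrange(per_view, "(B V) C T H W -> B C (V T) H W", V=V)
print(torch.equal(back, x))         # True
```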
+ + Args: + model (DiffusionMultiviewV2WModel): Video generation model + input_image_or_video_path (str): Path to conditioning image/video + num_input_frames (int): Number of input frames for video2world prediction + + Returns: + tuple: (condition_latent, input_frames) where: + - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] + - input_frames (torch.Tensor): Input frames tensor [B,C,T,H,W] + """ + if state_shape is None: + state_shape = model.state_shape + assert num_input_frames > 0, "num_input_frames must be greater than 0" + + H, W = ( + state_shape[-2] * model.tokenizer.spatial_compression_factor, + state_shape[-1] * model.tokenizer.spatial_compression_factor, + ) + input_path_format = input_image_or_video_path.split(".")[-1] + input_frames = read_video_or_image_into_frames_BCTHW( + input_image_or_video_path, + input_path_format=input_path_format, + H=H, + W=W, + ) + input_frames = einops.rearrange(input_frames, "B C (V T) H W -> (B V) C T H W", V=model.n_views) + condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames) + condition_latent = condition_latent.to(torch.bfloat16) + + return condition_latent, einops.rearrange(input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views)[0] + + +def check_input_frames(input_path: str, required_frames: int) -> bool: + """Check if input video/image has sufficient frames. + + Args: + input_path: Path to input video or image + required_frames: Number of required frames + + Returns: + bool: True if input has sufficient frames, False otherwise + """ + if input_path.endswith((".jpg", ".jpeg", ".png")): + if required_frames > 1: + log.error(f"Input ({input_path}) is an image but {required_frames} frames are required") + return False + return True # Let the pipeline handle image loading + # For video input + try: + vid = imageio.get_reader(input_path, "ffmpeg") + frame_count = vid.count_frames() + + if frame_count < required_frames: + log.error(f"Input video has {frame_count} frames but {required_frames} frames are required") + return False + else: + return True + except Exception as e: + log.error(f"Error reading video file {input_path}: {e}") + return False + + +def get_input_sizes(input_path: str) -> tuple[int, int]: + """Get the height and width of input video or image. + + Args: + input_path: Path to input video or image file + + Returns: + tuple: (height, width) dimensions of the input + """ + if input_path.endswith((".jpg", ".jpeg", ".png")): + # For image input + try: + img = imageio.imread(input_path) + return img.shape[0], img.shape[1] + except Exception as e: + log.error(f"Error reading image file {input_path}: {e}") + raise + else: + # For video input + try: + vid = imageio.get_reader(input_path, "ffmpeg") + first_frame = vid.get_data(0) + return first_frame.shape[0], first_frame.shape[1] + except Exception as e: + log.error(f"Error reading video file {input_path}: {e}") + raise diff --git a/cosmos_predict1/diffusion/inference/ray_triangle_intersection_warp.py b/cosmos_predict1/diffusion/inference/ray_triangle_intersection_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..e11589b0a9a3395db4b14293f3461a5572035d49 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/ray_triangle_intersection_warp.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import warp as wp +import numpy as np + +# Initialize Warp with CUDA +wp.init() + +@wp.kernel +def ray_triangle_intersection_kernel( + ray_origins: wp.array2d(dtype=wp.float32), # (H*W, 3) + ray_directions: wp.array2d(dtype=wp.float32), # (H*W, 3) + vertices: wp.array2d(dtype=wp.float32), # (N, 3) + faces: wp.array2d(dtype=wp.int32), # (M, 3) + depth_map: wp.array(dtype=wp.float32), # (H*W,) + num_triangles: wp.int32, + epsilon: wp.float32 +): + """ + Warp kernel for ray-triangle intersection using Möller–Trumbore algorithm. + Each thread processes one ray against all triangles. + """ + # Get thread index (ray index) + ray_idx = wp.tid() + + # Get ray origin and direction + ray_origin = wp.vec3( + ray_origins[ray_idx, 0], + ray_origins[ray_idx, 1], + ray_origins[ray_idx, 2] + ) + ray_dir = wp.vec3( + ray_directions[ray_idx, 0], + ray_directions[ray_idx, 1], + ray_directions[ray_idx, 2] + ) + + # Initialize minimum distance + min_t = wp.float32(1e10) + + # Iterate through all triangles + for tri_idx in range(num_triangles): + # Get triangle vertex indices + i0 = faces[tri_idx, 0] + i1 = faces[tri_idx, 1] + i2 = faces[tri_idx, 2] + + # Get triangle vertices + v0 = wp.vec3(vertices[i0, 0], vertices[i0, 1], vertices[i0, 2]) + v1 = wp.vec3(vertices[i1, 0], vertices[i1, 1], vertices[i1, 2]) + v2 = wp.vec3(vertices[i2, 0], vertices[i2, 1], vertices[i2, 2]) + + # Compute edges + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Möller–Trumbore algorithm + h = wp.cross(ray_dir, edge2) + a = wp.dot(edge1, h) + + # Check if ray is parallel to triangle + if wp.abs(a) < epsilon: + continue + + f = 1.0 / a + s = ray_origin - v0 + u = f * wp.dot(s, h) + + # Check if intersection is within triangle (u >= 0 and u <= 1) + if u < 0.0 or u > 1.0: + continue + + q = wp.cross(s, edge1) + v = f * wp.dot(ray_dir, q) + + # Check if intersection is within triangle (v >= 0 and u + v <= 1) + if v < 0.0 or (u + v) > 1.0: + continue + + # Compute t (distance along ray) + t = f * wp.dot(edge2, q) + + # Only consider intersections in front of camera (t > 0) + if t > epsilon and t < min_t: + min_t = t + + # Write result + if min_t < 1e10: + depth_map[ray_idx] = min_t + else: + depth_map[ray_idx] = 0.0 + + +@wp.kernel +def ray_triangle_intersection_tiled_kernel( + ray_origins: wp.array2d(dtype=wp.float32), # (H*W, 3) + ray_directions: wp.array2d(dtype=wp.float32), # (H*W, 3) + vertices: wp.array2d(dtype=wp.float32), # (N, 3) + faces: wp.array2d(dtype=wp.int32), # (M, 3) + depth_map: wp.array(dtype=wp.float32), # (H*W,) + tri_start: wp.int32, # Start triangle index for this tile + tri_end: wp.int32, # End triangle index for this tile + epsilon: wp.float32 +): + """ + Tiled version of ray-triangle intersection kernel. + Processes a subset of triangles to improve memory access patterns. 
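For readers who want a framework-agnostic reference for the Möller–Trumbore test used in both kernels, here is a vectorized PyTorch sketch that intersects a single ray against a batch of triangles and returns the closest positive hit distance (0.0 when there is none), using the same epsilon convention:

```python
import torch

def moller_trumbore_depth(origin, direction, v0, v1, v2, eps=1e-8):
    # origin, direction: (3,) tensors; v0, v1, v2: (M, 3) triangle vertices.
    e1, e2 = v1 - v0, v2 - v0
    h = torch.cross(direction.expand_as(e2), e2, dim=-1)
    a = (e1 * h).sum(-1)
    f = 1.0 / a
    s = origin - v0                      # broadcasts to (M, 3)
    u = f * (s * h).sum(-1)
    q = torch.cross(s, e1, dim=-1)
    v = f * (direction * q).sum(-1)
    t = f * (e2 * q).sum(-1)
    hit = (a.abs() > eps) & (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t > eps)
    t = torch.where(hit, t, torch.full_like(t, float("inf")))
    t_min = t.min()
    return torch.where(torch.isinf(t_min), torch.zeros_like(t_min), t_min)

# One triangle in the z = 2 plane, ray shooting along +z from the origin.
v0 = torch.tensor([[-1.0, -1.0, 2.0]])
v1 = torch.tensor([[ 1.0, -1.0, 2.0]])
v2 = torch.tensor([[ 0.0,  1.0, 2.0]])
t = moller_trumbore_depth(torch.zeros(3), torch.tensor([0.0, 0.0, 1.0]), v0, v1, v2)
print(t)  # tensor(2.)
```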
+ """ + # Get thread index (ray index) + ray_idx = wp.tid() + + # Get ray origin and direction + ray_origin = wp.vec3( + ray_origins[ray_idx, 0], + ray_origins[ray_idx, 1], + ray_origins[ray_idx, 2] + ) + ray_dir = wp.vec3( + ray_directions[ray_idx, 0], + ray_directions[ray_idx, 1], + ray_directions[ray_idx, 2] + ) + + # Get current minimum distance + min_t = depth_map[ray_idx] + if min_t == 0.0: + min_t = wp.float32(1e10) + + # Process triangles in this tile + for tri_idx in range(tri_start, tri_end): + # Get triangle vertex indices + i0 = faces[tri_idx, 0] + i1 = faces[tri_idx, 1] + i2 = faces[tri_idx, 2] + + # Get triangle vertices + v0 = wp.vec3(vertices[i0, 0], vertices[i0, 1], vertices[i0, 2]) + v1 = wp.vec3(vertices[i1, 0], vertices[i1, 1], vertices[i1, 2]) + v2 = wp.vec3(vertices[i2, 0], vertices[i2, 1], vertices[i2, 2]) + + # Compute edges + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Möller–Trumbore algorithm + h = wp.cross(ray_dir, edge2) + a = wp.dot(edge1, h) + + # Check if ray is parallel to triangle + if wp.abs(a) < epsilon: + continue + + f = 1.0 / a + s = ray_origin - v0 + u = f * wp.dot(s, h) + + # Check if intersection is within triangle (u >= 0 and u <= 1) + if u < 0.0 or u > 1.0: + continue + + q = wp.cross(s, edge1) + v = f * wp.dot(ray_dir, q) + + # Check if intersection is within triangle (v >= 0 and u + v <= 1) + if v < 0.0 or (u + v) > 1.0: + continue + + # Compute t (distance along ray) + t = f * wp.dot(edge2, q) + + # Only consider intersections in front of camera (t > 0) + if t > epsilon and t < min_t: + min_t = t + + # Write result using atomic min to handle concurrent updates + if min_t < 1e10: + wp.atomic_min(depth_map, ray_idx, min_t) + + +def ray_triangle_intersection_warp( + ray_origins: torch.Tensor, # (H, W, 3) + ray_directions: torch.Tensor, # (H, W, 3) + vertices: torch.Tensor, # (N, 3) + faces: torch.Tensor, # (M, 3) + device: torch.device +) -> torch.Tensor: + """ + Compute ray-triangle intersections using NVIDIA Warp for maximum GPU acceleration. + + This implementation uses Warp kernels to achieve the best possible performance + on NVIDIA GPUs by: + 1. Using native CUDA kernels through Warp + 2. Tiling triangles for better memory access patterns + 3. Using atomic operations for concurrent updates + 4. 
Minimizing memory transfers + + Args: + ray_origins: (H, W, 3) ray origins in camera space + ray_directions: (H, W, 3) ray directions (should be normalized) + vertices: (N, 3) mesh vertices + faces: (M, 3) triangle face indices + device: torch device (must be CUDA) + + Returns: + depth_map: (H, W) depth values, 0 where no intersection + """ + H, W = ray_origins.shape[:2] + num_rays = H * W + num_triangles = faces.shape[0] + + # Reshape rays to 2D arrays + ray_origins_flat = ray_origins.reshape(-1, 3).contiguous() + ray_directions_flat = ray_directions.reshape(-1, 3).contiguous() + + # Convert PyTorch tensors to Warp arrays (as float arrays, not vec3) + wp_ray_origins = wp.from_torch(ray_origins_flat, dtype=wp.float32) + wp_ray_directions = wp.from_torch(ray_directions_flat, dtype=wp.float32) + wp_vertices = wp.from_torch(vertices.contiguous(), dtype=wp.float32) + wp_faces = wp.from_torch(faces.int().contiguous(), dtype=wp.int32) + + # Create output depth map + depth_map_flat = torch.zeros(num_rays, device=device, dtype=torch.float32) + wp_depth_map = wp.from_torch(depth_map_flat, dtype=wp.float32) + + # Choose implementation based on problem size + if num_triangles < 10000: + # For smaller meshes, use simple kernel + wp.launch( + kernel=ray_triangle_intersection_kernel, + dim=num_rays, + inputs=[ + wp_ray_origins, + wp_ray_directions, + wp_vertices, + wp_faces, + wp_depth_map, + num_triangles, + 1e-8 # epsilon + ], + device=f"cuda:{device.index}" if device.index is not None else "cuda:0" + ) + else: + # For larger meshes, use tiled approach for better memory access + triangle_tile_size = 10000 # Process triangles in tiles + + # Initialize depth map to infinity + depth_map_flat.fill_(float('inf')) + + # Process triangles in tiles + for tri_start in range(0, num_triangles, triangle_tile_size): + tri_end = min(tri_start + triangle_tile_size, num_triangles) + + wp.launch( + kernel=ray_triangle_intersection_tiled_kernel, + dim=num_rays, + inputs=[ + wp_ray_origins, + wp_ray_directions, + wp_vertices, + wp_faces, + wp_depth_map, + tri_start, + tri_end, + 1e-8 # epsilon + ], + device=f"cuda:{device.index}" if device.index is not None else "cuda:0" + ) + + # Convert infinity back to 0 + depth_map_flat[depth_map_flat == float('inf')] = 0.0 + + # Synchronize to ensure kernel completion + wp.synchronize() + + # Reshape back to 2D + depth_map = depth_map_flat.reshape(H, W) + + return depth_map diff --git a/cosmos_predict1/diffusion/inference/text2world.py b/cosmos_predict1/diffusion/inference/text2world.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac81213e0fea581f69de3bf5008820b718e45c0 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/text2world.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
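The ray-caster above expects per-pixel ray origins and directions in camera space; for a pinhole model these follow from the intrinsics matrix as in this sketch (the half-pixel center offset is a convention choice, and normalizing the directions matches the "should be normalized" note in the docstring):

```python
import torch

def pinhole_rays(K: torch.Tensor, H: int, W: int):
    # K: (3, 3) pixel-space intrinsics. Returns (H, W, 3) ray origins and directions.
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    v, u = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                          torch.arange(W, dtype=torch.float32), indexing="ij")
    dirs = torch.stack([(u + 0.5 - cx) / fx, (v + 0.5 - cy) / fy, torch.ones_like(u)], dim=-1)
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)
    origins = torch.zeros_like(dirs)  # camera center sits at the origin in camera space
    return origins, dirs

K = torch.tensor([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
origins, dirs = pinhole_rays(K, 720, 1280)
print(origins.shape, dirs.shape)  # torch.Size([720, 1280, 3]) torch.Size([720, 1280, 3])
```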
+ +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import add_common_arguments, validate_args +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Text to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add text2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Text2World", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Text2World", + "Cosmos-Predict1-14B-Text2World", + "Cosmos-Predict1-7B-Text2World_post-trained", + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_80gb", + "Cosmos-Predict1-7B-Text2World_post-trained-8gpu_40gb", + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_40gb", + "Cosmos-Predict1-7B-Text2World_post-trained-lora", + "Cosmos-Predict1-14B-Text2World_post-trained", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Cosmos-UpsamplePrompt1-12B-Text2World", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + + parser.add_argument( + "--word_limit_to_skip_upsampler", + type=int, + default=250, + help="Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value", + ) + + return parser.parse_args() + + +def demo(args): + """Run text-to-world generation demo. + + This function handles the main text-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts from input + - Generating videos from text prompts + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. 
+ """ + misc.set_random_seed(args.seed) + inference_type = "text2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize text2world generation model pipeline + pipeline = DiffusionText2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + seed=args.seed, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt}] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None: + log.critical("Prompt is missing, skipping world generation.") + continue + + # Generate video + generated_output = pipeline.generate(current_prompt, args.negative_prompt, args.word_limit_to_skip_upsampler) + if generated_output is None: + log.critical("Guardrail blocked text2world generation.") + continue + video, prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + f.write(prompt.encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/text2world_multiview.py b/cosmos_predict1/diffusion/inference/text2world_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..717a7a9a4e267cc49d5ec9df31798f3121cff67f --- /dev/null +++ b/cosmos_predict1/diffusion/inference/text2world_multiview.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import add_common_arguments, remove_argument, validate_args +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionText2WorldMultiviewGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Text to world generation demo script") + # Add common arguments + add_common_arguments(parser) + remove_argument(parser, "width") + remove_argument(parser, "height") + remove_argument(parser, "num_video_frames") + parser.add_argument("--height", type=int, default=480, help="Height of video to sample") + parser.add_argument("--width", type=int, default=848, help="Width of video to sample") + parser.add_argument( + "--num_video_frames", + type=int, + default=57, + choices=[57], + help="Number of video frames to sample, this is per-camera frame number.", + ) + # Add text2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview", + ], + ) + parser.add_argument( + "--prompt_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the left. ", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the right.", + help="Text prompt for generating right camera view video", + ) + parser.add_argument( + "--prompt_back", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing backwards.", + help="Text prompt for generating rear camera view video", + ) + parser.add_argument( + "--prompt_back_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear left side.", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_back_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear right side.", + help="Text prompt for generating right camera view video", + ) + parser.add_argument( + "--frame_repeat_negative_condition", + type=float, + default=10.0, + help="frame_repeat number to be used as negative condition", + ) + + return parser.parse_args() + + +def demo(args): + """Run multi-view text-to-world generation demo. 
+ + This function handles the main text-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts from input + - Generating videos from text prompts + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "text2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize text2world generation model pipeline + pipeline = DiffusionText2WorldMultiviewGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + frame_repeat_negative_condition=args.frame_repeat_negative_condition, + seed=args.seed, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [ + { + "prompt": args.prompt, + "prompt_left": args.prompt_left, + "prompt_right": args.prompt_right, + "prompt_back": args.prompt_back, + "prompt_back_left": args.prompt_back_left, + "prompt_back_right": args.prompt_back_right, + } + ] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, current_prompt in enumerate(prompts): + # Generate video + generated_output = pipeline.generate(current_prompt) + if generated_output is None: + log.critical("Guardrail blocked text2world generation.") + continue + [video_grid, video], prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{i}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, 
+ W=args.width, + video_save_quality=10, + video_save_path=video_save_path, + ) + + save_video( + video=video_grid, + fps=args.fps, + H=args.height * 2, + W=args.width * 3, + video_save_quality=5, + video_save_path=video_grid_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + for key, value in prompt.items(): + f.write(value.encode("utf-8")) + f.write("\n".encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/video2world.py b/cosmos_predict1/diffusion/inference/video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..04acd69075831641a026174674d801299cd26191 --- /dev/null +++ b/cosmos_predict1/diffusion/inference/video2world.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, + check_input_frames, + get_input_sizes, + validate_args, +) +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add video2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Video2World", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Video2World", + "Cosmos-Predict1-14B-Video2World", + "Cosmos-Predict1-7B-Video2World_post-trained", + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_80gb", + "Cosmos-Predict1-7B-Video2World_post-trained-8gpu_40gb", + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_40gb", + "Cosmos-Predict1-7B-Video2World_post-trained-lora", + "Cosmos-Predict1-14B-Video2World_post-trained", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input video/image path for generating a single video", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=1, + help="Number of input frames for video2world prediction", + choices=[1, 9], + ) + + return parser.parse_args() 
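The demo() function below wires these arguments into DiffusionVideo2WorldGenerationPipeline and loops over one or more prompts. As a rough sketch of the same flow without the CLI (the checkpoint root, conditioning frame path, and prompt text are placeholders; the numeric settings simply mirror the pipeline's own defaults):

from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
from cosmos_predict1.utils.io import save_video

pipeline = DiffusionVideo2WorldGenerationPipeline(
    inference_type="video2world",
    checkpoint_dir="checkpoints",                      # placeholder checkpoint root
    checkpoint_name="Cosmos-Predict1-7B-Video2World",
    prompt_upsampler_dir="Pixtral-12B",
    enable_prompt_upsampler=False,                     # use the prompt below as-is
    guidance=7.0, num_steps=35,
    height=704, width=1280, fps=24, num_video_frames=121,
    seed=0, num_input_frames=1,
)

result = pipeline.generate(
    prompt="A robot arm picks up a red cube from a cluttered table.",
    image_or_video_path="inputs/first_frame.png",      # placeholder conditioning image
    negative_prompt=None,
)
if result is not None:                                 # None means a guardrail blocked generation
    video, final_prompt = result
    save_video(video=video, fps=24, H=704, W=1280,
               video_save_quality=5, video_save_path="outputs/example.mp4")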
+ + +def demo(args): + """Run video-to-world generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = DiffusionVideo2WorldGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + seed=args.seed, + num_input_frames=args.num_input_frames, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_or_video_path}] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_image_or_video_path = input_dict.get("visual_input", None) + if current_image_or_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Check input frames + if not check_input_frames(current_image_or_video_path, args.num_input_frames): + continue + log.warning("Visual input is provided, overriding --height and --width arguments.") + args.height, args.width = get_input_sizes(current_image_or_video_path) + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + 
image_or_video_path=current_image_or_video_path, + negative_prompt=args.negative_prompt, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + f.write(prompt.encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/video2world_multiview.py b/cosmos_predict1/diffusion/inference/video2world_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..9bca0721e8a0c57476e73206dc6838aaaab3234b --- /dev/null +++ b/cosmos_predict1/diffusion/inference/video2world_multiview.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
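Each demo script in this diff repeats the same multi-GPU setup (distributed init plus Megatron context-parallel groups) and the matching teardown after the generation loop. Purely as an illustration of that pattern, and not code from the repository, it can be condensed into a small context manager:

import contextlib

import torch.distributed as dist
from megatron.core import parallel_state

from cosmos_predict1.utils import distributed

@contextlib.contextmanager
def context_parallel_group(num_gpus: int):
    # Mirrors the per-script setup/teardown: initialize distributed state and the
    # context-parallel group, then destroy both once generation finishes.
    if num_gpus > 1:
        distributed.init()
        parallel_state.initialize_model_parallel(context_parallel_size=num_gpus)
    try:
        yield parallel_state.get_context_parallel_group() if num_gpus > 1 else None
    finally:
        if num_gpus > 1:
            parallel_state.destroy_model_parallel()
            dist.destroy_process_group()

A script would then wrap its generation loop in "with context_parallel_group(args.num_gpus) as group:" and call pipeline.model.net.enable_context_parallel(group) whenever group is not None.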
+ +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + add_common_arguments, + check_input_frames, + get_input_sizes, + remove_argument, + validate_args, +) +from cosmos_predict1.diffusion.inference.world_generation_pipeline import ( + DiffusionVideo2WorldMultiviewGenerationPipeline, +) +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + remove_argument(parser, "width") + remove_argument(parser, "height") + remove_argument(parser, "num_video_frames") + parser.add_argument("--height", type=int, default=480, help="Height of video to sample") + parser.add_argument("--width", type=int, default=848, help="Width of video to sample") + + parser.add_argument( + "--num_video_frames", + type=int, + default=57, + choices=[57], + help="Number of video frames to sample, this is per-camera frame number.", + ) + # Add video2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview", + ], + ) + parser.add_argument( + "--prompt_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the left. ", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing to the right.", + help="Text prompt for generating right camera view video", + ) + + parser.add_argument( + "--prompt_back", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing backwards.", + help="Text prompt for generating rear camera view video", + ) + parser.add_argument( + "--prompt_back_left", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear left side.", + help="Text prompt for generating left camera view video", + ) + parser.add_argument( + "--prompt_back_right", + type=str, + default="The video is captured from a camera mounted on a car. The camera is facing the rear right side.", + help="Text prompt for generating right camera view video", + ) + parser.add_argument( + "--frame_repeat_negative_condition", + type=float, + default=10.0, + help="frame_repeat number to be used as negative condition", + ) + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input video/image path for generating a single video", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=1, + help="Number of input frames for video2world prediction", + choices=[1, 9], + ) + + return parser.parse_args() + + +def demo(args): + """Run multi-view video-to-world generation demo. 
+ + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + cfg (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available. + """ + misc.set_random_seed(args.seed) + inference_type = "video2world" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video2world generation model pipeline + pipeline = DiffusionVideo2WorldMultiviewGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + guidance=args.guidance, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + frame_repeat_negative_condition=args.frame_repeat_negative_condition, + seed=args.seed, + num_input_frames=args.num_input_frames, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [ + { + "prompt": args.prompt, + "prompt_left": args.prompt_left, + "prompt_right": args.prompt_right, + "prompt_back": args.prompt_back, + "prompt_back_left": args.prompt_back_left, + "prompt_back_right": args.prompt_back_right, + "visual_input": args.input_image_or_video_path, + } + ] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_image_or_video_path = input_dict.pop("visual_input", None) + if current_image_or_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + current_prompt = input_dict + + # Check input frames + if not check_input_frames(current_image_or_video_path, args.num_input_frames): + continue + log.warning("Visual input is provided, overriding --height and --width arguments.") + args.height, args.width = get_input_sizes(current_image_or_video_path) + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_or_video_path=current_image_or_video_path, + ) + if generated_output is None: + log.critical("Guardrail 
blocked video2world generation.") + continue + [video_grid, video], prompt = generated_output + + if args.batch_input_path: + video_save_path = os.path.join(args.video_save_folder, f"{i}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{i}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{i}.txt") + else: + video_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4") + video_grid_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}_grid.mp4") + prompt_save_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.txt") + + # Save video + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=10, + video_save_path=video_save_path, + ) + save_video( + video=video_grid, + fps=args.fps, + H=args.height * 2, + W=args.width * 3, + video_save_quality=5, + video_save_path=video_grid_save_path, + ) + # Save prompt to text file alongside video + with open(prompt_save_path, "wb") as f: + for key, value in prompt.items(): + f.write(value.encode("utf-8")) + f.write("\n".encode("utf-8")) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/inference/world_generation_pipeline.py b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..422b542059b4ebd4b08bf17b197cec604bc51ffa --- /dev/null +++ b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py @@ -0,0 +1,1464 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
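The multiview pipelines defined later in this file decode all six camera streams from a single latent whose time axis interleaves the views, then tile them into a 2x3 preview grid. The following shape-only sketch walks through that rearrangement with small illustrative dimensions in place of the real video resolution; the row layout comment is inferred from the prompt ordering used by the multiview generate() methods.

import einops
import torch

b, c, v, t, h, w = 1, 3, 6, 4, 48, 84            # illustrative sizes only
video = torch.zeros(b, c, v * t, h, w)           # six views stacked along the time axis
views = einops.rearrange(video, "b c (v t) h w -> b c v t h w", v=v)

# Reorder views with the index order [1, 0, 2, 4, 3, 5] used in _run_tokenizer_decoding,
# which appears to place left/front/right on the top row and back-left/back/back-right below.
grid = torch.stack([views[:, :, i] for i in [1, 0, 2, 4, 3, 5]], dim=2)
grid = einops.rearrange(grid, "b c (gh gw) t h1 w1 -> b c t (gh h1) (gw w1)", gh=2, gw=3)
assert grid.shape == (b, c, t, 2 * h, 3 * w)     # one 2x3 camera grid per frame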
+ +import gc +import os +from typing import Any, Optional + +import einops +import numpy as np +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import ( + generate_world_from_text, + generate_world_from_video, + get_condition_latent, + get_condition_latent_multiview, + get_video_batch, + get_video_batch_for_multiview_model, + load_model_by_config, + load_network_model, + load_tokenizer_model, + read_video_or_image_into_frames_BCTHW, +) +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.model.model_t2w_multiview import DiffusionMultiviewT2WModel +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel +from cosmos_predict1.diffusion.model.model_v2w_multiview import DiffusionMultiviewV2WModel +from cosmos_predict1.diffusion.model.model_world_interpolator import DiffusionWorldInterpolatorWModel +from cosmos_predict1.diffusion.prompt_upsampler.text2world_prompt_upsampler_inference import ( + create_prompt_upsampler, + run_chat_completion, +) +from cosmos_predict1.diffusion.prompt_upsampler.video2world_prompt_upsampler_inference import ( + create_vlm_prompt_upsampler, + prepare_dialog, +) +from cosmos_predict1.diffusion.prompt_upsampler.video2world_prompt_upsampler_inference import ( + run_chat_completion as run_chat_completion_vlm, +) +from cosmos_predict1.diffusion.training.utils.inference_long_video import generate_video_from_batch_with_loop +from cosmos_predict1.utils import log +from cosmos_predict1.utils.base_world_generation_pipeline import BaseWorldGenerationPipeline + +MODEL_NAME_DICT = { + # text2world + "Cosmos-Predict1-7B-Text2World": "Cosmos_Predict1_Text2World_7B", + "Cosmos-Predict1-14B-Text2World": "Cosmos_Predict1_Text2World_14B", + "Cosmos-Predict1-7B-Text2World_post-trained": "Cosmos_Predict1_Text2World_7B_Post_trained", + "Cosmos-Predict1-14B-Text2World_post-trained": "Cosmos_Predict1_Text2World_14B_Post_trained", + # text2world low-memory + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_80gb": "Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_80gb", + "Cosmos-Predict1-7B-Text2World_post-trained-8gpu_40gb": "Cosmos_Predict1_Text2World_7B_Post_trained_8gpu_40gb", + "Cosmos-Predict1-7B-Text2World_post-trained-4gpu_40gb": "Cosmos_Predict1_Text2World_7B_Post_trained_4gpu_40gb", + # text2world lora + "Cosmos-Predict1-7B-Text2World_post-trained-lora": "Cosmos_Predict1_Text2World_7B_Post_trained_lora", + # video2world + "Cosmos-Predict1-7B-Video2World": "Cosmos_Predict1_Video2World_7B", + "Cosmos-Predict1-14B-Video2World": "Cosmos_Predict1_Video2World_14B", + "Cosmos-Predict1-7B-Video2World_post-trained": "Cosmos_Predict1_Video2World_7B_Post_trained", + "Cosmos-Predict1-14B-Video2World_post-trained": "Cosmos_Predict1_Video2World_14B_Post_trained", + # video2world low-memory + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_80gb": "Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_80gb", + "Cosmos-Predict1-7B-Video2World_post-trained-8gpu_40gb": "Cosmos_Predict1_Video2World_7B_Post_trained_8gpu_40gb", + "Cosmos-Predict1-7B-Video2World_post-trained-4gpu_40gb": "Cosmos_Predict1_Video2World_7B_Post_trained_4gpu_40gb", + # video2world lora + "Cosmos-Predict1-7B-Video2World_post-trained-lora": "Cosmos_Predict1_Video2World_7B_Post_trained_lora", + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview": "Cosmos_Predict1_Text2World_7B_Multiview", + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview": "Cosmos_Predict1_Video2World_7B_Multiview", + "Cosmos-Predict1-7B-WorldInterpolator": 
"Cosmos_Predict1_WorldInterpolator_7B", + # Gen3c + "Gen3C-Cosmos-7B": "GEN3C_Cosmos_7B", +} + + +class DiffusionText2WorldGenerationPipeline(BaseWorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + ): + """Initialize the diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + """ + assert inference_type in [ + "text2world", + "video2world", + "world_interpolator", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + self.model_name = MODEL_NAME_DICT[checkpoint_name] + self.guidance = guidance + self.num_steps = num_steps + self.height = height + self.width = width + self.fps = fps + self.num_video_frames = num_video_frames + self.seed = seed + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + ) + self.prompt_upsampler_dir = prompt_upsampler_dir + self.enable_prompt_upsampler = enable_prompt_upsampler + self.offload_prompt_upsampler = offload_prompt_upsampler + + self.prompt_upsampler = None + if enable_prompt_upsampler and not offload_prompt_upsampler: + self._load_prompt_upsampler_model() + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionT2WModel, + ) + + def _load_network(self): + load_network_model(self.model, f"{self.checkpoint_dir}/{self.checkpoint_name}/model.pt") + + def _load_tokenizer(self): + 
load_tokenizer_model(self.model, f"{self.checkpoint_dir}/Cosmos-Tokenize1-CV8x8x8-720p") + + def _offload_prompt_upsampler_model(self): + """Move prompt enhancement model to CPU/disk. + + Offloads prompt upsampling model after processing input + to reduce GPU memory usage. + """ + if self.prompt_upsampler: + del self.prompt_upsampler + self.prompt_upsampler = None + gc.collect() + torch.cuda.empty_cache() + + def _run_prompt_upsampler_on_prompt(self, prompt: str) -> str: + """Enhance the input prompt using the prompt upsampler model. + + Args: + prompt: Raw text prompt to be enhanced + + Returns: + str: Enhanced version of the input prompt with more descriptive details + """ + upsampled_prompt = run_chat_completion(self.prompt_upsampler, prompt) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _run_prompt_upsampler_on_prompt_with_offload(self, *args: Any, **kwargs: Any) -> str: + """Enhance prompt with prompt upsampler model. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Enhanced prompt string + """ + if self.offload_prompt_upsampler: + self._load_prompt_upsampler_model() + + enhanced_prompt = self._run_prompt_upsampler_on_prompt(*args, **kwargs) + + if self.offload_prompt_upsampler: + self._offload_prompt_upsampler_model() + + return enhanced_prompt + + def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray: + """Decode latent samples to video frames using the tokenizer decoder. + + Args: + sample: Latent tensor from diffusion model [B, C, T, H, W] + + Returns: + np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C] + with values in range [0, 255] + """ + # Decode video + video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] + video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() + + return video + + def _run_model( + self, + embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Generate video latents using the diffusion model. + + Args: + embedding: Text embedding tensor from text encoder + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + torch.Tensor: Generated video latents before tokenizer decoding + + Note: + The model and tokenizer are automatically offloaded after inference + if offloading is enabled in the config. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + # Generate video frames + sample = generate_world_from_text( + model=self.model, + state_shape=state_shape, + is_negative_prompt=True if negative_prompt_embedding is not None else False, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + ) + + return sample + + def _run_model_with_offload( + self, prompt_embedding: torch.Tensor, negative_prompt_embedding: Optional[torch.Tensor] = None + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
+ + Args: + prompt_embedding: Text embedding tensor from text encoder + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation + """ + if self.offload_network: + self._load_network() + + if self.offload_tokenizer: + self._load_tokenizer() + + sample = self._run_model(prompt_embedding, negative_prompt_embedding) + + if self.offload_network: + self._offload_network() + + if self.offload_tokenizer: + self._load_tokenizer() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + return sample + + def generate( + self, + prompt: str, + negative_prompt: Optional[str] = None, + word_limit_to_skip_upsampler: Optional[int] = None, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt with optional negative prompt guidance. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + negative_prompt: Optional text to guide what not to generate + word_limit_to_skip_upsampler: Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + if not self.disable_guardrail: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + + # Enhance prompt + if self.enable_prompt_upsampler: + word_count = len(prompt.split()) + if word_limit_to_skip_upsampler is None or word_count <= word_limit_to_skip_upsampler: + log.info("Run prompt upsampler on prompt") + prompt = self._run_prompt_upsampler_on_prompt_with_offload(prompt) + if not self.disable_guardrail: + log.info("Run guardrail on upsampled prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt=prompt) + if not is_safe: + log.critical("Upsampled text prompt is not safe") + return None + log.info("Pass guardrail on upsampled prompt") + else: + log.info( + f"Skip prompt upsampler for better robustness because the number of words ({word_count}) in the prompt is greater than {word_limit_to_skip_upsampler}" + ) + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if 
video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return video, prompt + + +class DiffusionVideo2WorldGenerationPipeline(DiffusionText2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + num_input_frames: int = 1, + ): + """Initialize diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + num_input_frames: Number of latent conditions + """ + self.num_input_frames = num_input_frames + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=enable_prompt_upsampler, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + ) + + def _run_prompt_upsampler_on_prompt(self, image_or_video_path: str) -> str: + """Enhance the input prompt using visual context from the conditioning image. 
+ + Args: + image_or_video_path: Path to conditioning image or video used for visual context + + Returns: + str: Enhanced prompt incorporating visual details from the image + """ + dialog = prepare_dialog(image_or_video_path) + upsampled_prompt = run_chat_completion_vlm( + self.prompt_upsampler, dialog, max_gen_len=400, temperature=0.01, top_p=0.9, logprobs=False + ) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_vlm_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionV2WModel, + ) + + def _run_model( + self, + embedding: torch.Tensor, + condition_latent: torch.Tensor, + negative_prompt_embedding: torch.Tensor | None = None, + ) -> torch.Tensor: + """Generate video frames using the diffusion model. + + Args: + embedding: Text embedding tensor from T5 encoder + condition_latent: Latent tensor from conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + Tensor of generated video frames + + Note: + Model and tokenizer are automatically offloaded after inference + if offloading is enabled. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + # Generate video frames + video = generate_world_from_video( + model=self.model, + state_shape=self.model.state_shape, + is_negative_prompt=True, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + condition_latent=condition_latent, + num_input_frames=self.num_input_frames, + ) + + return video + + def _run_tokenizer_encoding(self, image_or_video_path: str) -> torch.Tensor: + """ + Encode image to latent space + + Args: + image_or_video_path: Path to conditioning image + + Returns: + torch.Tensor: Latent tensor from tokenizer encoding + """ + condition_latent = get_condition_latent( + model=self.model, + input_image_or_video_path=image_or_video_path, + num_input_frames=self.num_input_frames, + state_shape=self.model.state_shape, + ) + + return condition_latent + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. 
+ + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + condition_latent = self._run_tokenizer_encoding(image_or_video_path) + + if self.offload_network: + self._load_network() + + sample = self._run_model(prompt_embedding, condition_latent, negative_prompt_embedding) + + if self.offload_network: + self._offload_network() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + + return sample + + def generate( + self, + prompt: str, + image_or_video_path: str, + negative_prompt: Optional[str] = None, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt and optional image. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + image_or_video_path: Path to conditioning image or video + negative_prompt: Optional text to guide what not to generate + + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + + log.info(f"Run with image or video path: {image_or_video_path}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + if self.enable_prompt_upsampler: + log.info("Run prompt upsampler on image or video, input prompt is not used") + prompt = self._run_prompt_upsampler_on_prompt_with_offload(image_or_video_path=image_or_video_path) + + log.info(f"Run with prompt: {prompt}") + if not self.disable_guardrail: + log.info(f"Run guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical(f"Input {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt is not safe") + return None + log.info(f"Pass guardrail on {'upsampled' if self.enable_prompt_upsampler else 'text'} prompt") + else: + log.info("Not running guardrail") + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + image_or_video_path=image_or_video_path, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return video, prompt + + +class 
DiffusionText2WorldMultiviewGenerationPipeline(DiffusionText2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + n_views: int = 6, + frame_repeat_negative_condition: int = 10, + seed: int = 0, + ): + """Initialize the diffusion multi-view world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + n_views: Number of views + frame_repeat_negative_condition: Number of frames to repeat to be used as negative condition. + seed: Random seed for sampling + """ + assert inference_type in [ + "text2world", + "video2world", + ], "Invalid inference_type, must be 'text2world' or 'video2world'" + + self.n_views = n_views + self.frame_repeat_negative_condition = frame_repeat_negative_condition + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=False, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionMultiviewT2WModel, + ) + + def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray: + """Decode latent samples to video frames using the tokenizer decoder. 
+ + Args: + sample: Latent tensor from diffusion model [B, C, T, H, W] + + Returns: + np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C] + with values in range [0, 255] + """ + # Decode video + video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] + video_segments = einops.rearrange(video, "b c (v t) h w -> b c v t h w", v=self.n_views) + grid_video = torch.stack( + [video_segments[:, :, i] for i in [1, 0, 2, 4, 3, 5]], + dim=2, + ) + grid_video = einops.rearrange(grid_video, "b c (h w) t h1 w1 -> b c t (h h1) (w w1)", h=2, w=3) + grid_video = (grid_video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() + video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() + + return [grid_video, video] + + def _run_model( + self, + embedding: torch.Tensor, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Generate video latents using the diffusion model. + + Args: + embedding: Text embedding tensor from text encoder + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + torch.Tensor: Generated video latents before tokenizer decoding + + Note: + The model and tokenizer are automatically offloaded after inference + if offloading is enabled in the config. + """ + # Get video batch and state shape + data_batch, state_shape = get_video_batch_for_multiview_model( + model=self.model, + prompt_embedding=embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames * len(embedding), # number of views + frame_repeat_negative_condition=self.frame_repeat_negative_condition, + ) + + # Generate video frames + sample = generate_world_from_text( + model=self.model, + state_shape=state_shape, + is_negative_prompt=False, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + ) + + return sample + + def generate( + self, + prompt: dict, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt with optional negative prompt guidance. + + Pipeline steps: + 1. Convert prompt to embeddings + 2. Generate video frames using diffusion + + Args: + prompt: A dictionary of text description of desired video. 
+ Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + + prompts = [ + prompt["prompt"], + prompt["prompt_left"], + prompt["prompt_right"], + prompt["prompt_back"], + prompt["prompt_back_left"], + prompt["prompt_back_right"], + ] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + videos = self._run_model_with_offload( + prompt_embeddings, + ) + log.info("Finish generation") + + return videos, prompt + + +class DiffusionVideo2WorldMultiviewGenerationPipeline(DiffusionText2WorldMultiviewGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = 7.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 0, + num_input_frames: int = 1, + n_views: int = 6, + frame_repeat_negative_condition: int = 10, + ): + """Initialize diffusion world multi-view generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + num_input_frames: Number of latent conditions + """ + self.num_input_frames = num_input_frames + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + n_views=n_views, + frame_repeat_negative_condition=frame_repeat_negative_condition, + ) + + def _load_model(self): + self.model = load_model_by_config( + 
config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionMultiviewV2WModel, + ) + + def _run_model( + self, + embedding: torch.Tensor, + condition_latent: torch.Tensor, + negative_prompt_embedding: torch.Tensor | None = None, + data_batch: dict = None, + state_shape: list = None, + ) -> torch.Tensor: + """Generate video frames using the diffusion model. + + Args: + embedding: Text embedding tensor from T5 encoder + condition_latent: Latent tensor from conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + Tensor of generated video frames + + Note: + Model and tokenizer are automatically offloaded after inference + if offloading is enabled. + """ + # Generate video frames + video = generate_world_from_video( + model=self.model, + state_shape=state_shape, + is_negative_prompt=False, + data_batch=data_batch, + guidance=self.guidance, + num_steps=self.num_steps, + seed=self.seed, + condition_latent=condition_latent, + num_input_frames=self.num_input_frames, + ) + + return video + + def _run_tokenizer_encoding(self, image_or_video_path: str, state_shape: list) -> torch.Tensor: + """ + Encode image to latent space + + Args: + image_or_video_path: Path to conditioning image + + Returns: + torch.Tensor: Latent tensor from tokenizer encoding + """ + condition_latent, condition_frames = get_condition_latent_multiview( + model=self.model, + input_image_or_video_path=image_or_video_path, + num_input_frames=self.num_input_frames, + state_shape=state_shape, + ) + + return condition_latent, condition_frames + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + negative_prompt_embedding: Optional[torch.Tensor] = None, + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. + + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + data_batch, state_shape = get_video_batch_for_multiview_model( + model=self.model, + prompt_embedding=prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames * len(prompt_embedding), # number of views + frame_repeat_negative_condition=self.frame_repeat_negative_condition, + ) + + condition_latent, condition_frames = self._run_tokenizer_encoding(image_or_video_path, state_shape) + + if self.offload_network: + self._load_network() + + sample = self._run_model(prompt_embedding, condition_latent, negative_prompt_embedding, data_batch, state_shape) + + if self.offload_network: + self._offload_network() + + sample = self._run_tokenizer_decoding(sample) + + if self.offload_tokenizer: + self._offload_tokenizer() + + return sample + + def generate( + self, + prompt: dict, + image_or_video_path: str, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt with optional negative prompt guidance. + + Pipeline steps: + 1. Convert prompt to embeddings + 2. Generate video frames using diffusion + + Args: + prompt: A dictionary of text description of desired video. 
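+ Expected keys: "prompt", "prompt_left", "prompt_right", "prompt_back", "prompt_back_left", "prompt_back_right" (one text description per camera view).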
+ Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + + prompts = [ + prompt["prompt"], + prompt["prompt_left"], + prompt["prompt_right"], + prompt["prompt_back"], + prompt["prompt_back_left"], + prompt["prompt_back_right"], + ] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embeddings, + image_or_video_path=image_or_video_path, + ) + log.info("Finish generation") + + return video, prompt + + +class DiffusionWorldInterpolatorGenerationPipeline(DiffusionVideo2WorldGenerationPipeline): + def __init__( + self, + inference_type: str, + checkpoint_dir: str, + checkpoint_name: str, + prompt_upsampler_dir: Optional[str] = None, + enable_prompt_upsampler: bool = True, + has_text_input: bool = True, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_prompt_upsampler: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + guidance: float = -1.0, + num_steps: int = 35, + height: int = 704, + width: int = 1280, + fps: int = 24, + num_video_frames: int = 121, + seed: int = 11, + num_input_frames: int = 1, + num_frame_pairs: int = 1, + frame_index_start: int = 0, + frame_stride: int = 1, + ): + """Initialize diffusion world generation pipeline. + + Args: + inference_type: Type of world generation ('text2world' or 'video2world') + checkpoint_dir: Base directory containing model checkpoints + checkpoint_name: Name of the diffusion transformer checkpoint to use + prompt_upsampler_dir: Directory containing prompt upsampler model weights + enable_prompt_upsampler: Whether to use prompt upsampling + has_text_input: Whether the pipeline takes text input for world generation + offload_network: Whether to offload diffusion transformer after inference + offload_tokenizer: Whether to offload tokenizer after inference + offload_text_encoder_model: Whether to offload T5 model after inference + offload_prompt_upsampler: Whether to offload prompt upsampler + offload_guardrail_models: Whether to offload guardrail models + disable_guardrail: Whether to disable guardrail + guidance: Classifier-free guidance scale + num_steps: Number of diffusion sampling steps + height: Height of output video + width: Width of output video + fps: Frames per second of output video + num_video_frames: Number of frames to generate + seed: Random seed for sampling + num_input_frames: Number of latent conditions + """ + self.num_input_frames = num_input_frames + self.num_frame_pairs = num_frame_pairs + self.frame_index_start = frame_index_start + self.frame_stride = frame_stride + self.num_steps = num_steps + self.height = height + self.width = width + self.fps = fps + + super().__init__( + inference_type=inference_type, + checkpoint_dir=checkpoint_dir, + checkpoint_name=checkpoint_name, + prompt_upsampler_dir=prompt_upsampler_dir, + enable_prompt_upsampler=enable_prompt_upsampler, + has_text_input=has_text_input, + offload_network=offload_network, + offload_tokenizer=offload_tokenizer, + offload_text_encoder_model=offload_text_encoder_model, + offload_prompt_upsampler=offload_prompt_upsampler, + offload_guardrail_models=offload_guardrail_models, + 
disable_guardrail=disable_guardrail, + guidance=guidance, + num_steps=num_steps, + height=height, + width=width, + fps=fps, + num_video_frames=num_video_frames, + seed=seed, + num_input_frames=num_input_frames, + ) + + def _run_prompt_upsampler_on_prompt(self, image_or_video_path: str) -> str: + """Enhance the input prompt using visual context from the conditioning image. + + Args: + image_or_video_path: Path to conditioning image or video used for visual context + + Returns: + str: Enhanced prompt incorporating visual details from the image + """ + dialog = prepare_dialog(image_or_video_path) + upsampled_prompt = run_chat_completion_vlm( + self.prompt_upsampler, dialog, max_gen_len=400, temperature=0.01, top_p=0.9, logprobs=False + ) + log.info(f"Upsampled prompt: {upsampled_prompt}") + return upsampled_prompt + + def _load_prompt_upsampler_model(self): + self.prompt_upsampler = create_vlm_prompt_upsampler( + checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir), + ) + + def _load_model(self): + self.model = load_model_by_config( + config_job_name=self.model_name, + config_file="cosmos_predict1/diffusion/config/config.py", + model_class=DiffusionWorldInterpolatorWModel, + ) + + @torch.inference_mode() + def _run_model( + self, + condition_latent: torch.Tensor | None = None, + negative_prompt_embedding: torch.Tensor | None = None, + num_of_loops: int = 1, + num_of_latent_overlap_list: list[int] = [1], + augment_sigma_list: list[float] = [0.001], + add_input_frames_guidance: float = 0, + skip_reencode: int = 0, + state_shape: list = None, + raw_data_batch: dict = None, + ) -> np.ndarray: + """Generate video frames using the diffusion model, supporting chunk processing for video extension. + + Args: + condition_latent: Latent tensor from conditioning image or video (optional for video extension). + negative_prompt_embedding: Optional embedding for negative prompt guidance. + num_of_loops: Number of loops for generating video segments. + num_of_latent_overlap_list: List of overlaps for latent conditions in each loop. + augment_sigma_list: List of sigma values for augmentation. + add_input_frames_guidance: Guidance strength for input frames. + skip_reencode: Whether to skip reencoding. + frame_index_start: Starting index for frame pairs. + num_frame_pairs: Number of frame pairs to process. + frame_stride: Stride between frame pairs. + is_interpolator_model: Whether the model is an interpolator. + input_frames: Input video frames for interpolation (optional). + + Returns: + np.ndarray: Generated video frames in shape (T, H, W, C). 
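+ Note: Frame-pair selection (frame_index_start, num_frame_pairs, frame_stride) is handled by the caller, _run_model_with_offload; this method generates a single segment from the provided condition_latent.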
+ """ + video_np_THWC, _, _ = generate_video_from_batch_with_loop( + model=self.model, + data_batch=raw_data_batch, + condition_latent=condition_latent, + num_of_loops=num_of_loops, + num_of_latent_overlap_list=num_of_latent_overlap_list, + guidance=self.guidance, + state_shape=state_shape, + num_steps=self.num_steps, + seed=self.seed, + is_negative_prompt=True if negative_prompt_embedding is not None else False, + visualize=False, + save_fig_path=None, + augment_sigma_list=augment_sigma_list, + add_input_frames_guidance=add_input_frames_guidance, + skip_reencode=skip_reencode, + ) + + return video_np_THWC + + def _run_tokenizer_encoding( + self, image_or_video_path: str, frame_index: int = 0, frame_stride: int = 1 + ) -> torch.Tensor: + """Encode image to latent space + + Args: + image_or_video_path: Path to conditioning image + frame_index: Starting frame index for encoding + frame_stride: Stride between frames for encoding + + Returns: + torch.Tensor: Latent tensor from tokenizer encoding + """ + condition_latent = get_condition_latent( + model=self.model, + input_image_or_video_path=image_or_video_path, + num_input_frames=self.num_input_frames, + state_shape=self.model.state_shape, + frame_index=frame_index, + frame_stride=frame_stride, + ) + + return condition_latent + + def _run_model_with_offload( + self, + prompt_embedding: torch.Tensor, + image_or_video_path: str, + negative_prompt_embedding: Optional[torch.Tensor] = None, + frame_index_start: int = 0, + num_frame_pairs: int = 1, + ) -> np.ndarray: + """Generate world representation with automatic model offloading. + + Wraps the core generation process with model loading/offloading logic + to minimize GPU memory usage during inference. + + Args: + prompt_embedding: Text embedding tensor from T5 encoder + image_or_video_path: Path to conditioning image or video + negative_prompt_embedding: Optional embedding for negative prompt guidance + frame_index_start: Starting index for frame pairs + num_frame_pairs: Number of frame pairs to process + + Returns: + np.ndarray: Generated world representation as numpy array + """ + if self.offload_tokenizer: + self._load_tokenizer() + + # Prepare video batch and state shape + raw_data_batch, state_shape = get_video_batch( + model=self.model, + prompt_embedding=prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + height=self.height, + width=self.width, + fps=self.fps, + num_video_frames=self.num_video_frames, + ) + + H, W = ( + state_shape[-2] * self.model.tokenizer.spatial_compression_factor, + state_shape[-1] * self.model.tokenizer.spatial_compression_factor, + ) + + input_path_format = image_or_video_path.split(".")[-1] + input_frames = read_video_or_image_into_frames_BCTHW( + image_or_video_path, + input_path_format=input_path_format, + H=H, + W=W, + ) + + num_frames = input_frames.shape[2] + num_frame_pairs = num_frame_pairs or num_frames // self.frame_stride + frame_stride = self.frame_stride + + video_output = [] + for frame_index in range(frame_index_start, num_frame_pairs): + print(f"Processing frame pair {frame_index + 1} / {num_frame_pairs}...") + + condition_latent = self._run_tokenizer_encoding(image_or_video_path, frame_index, frame_stride) + + video_np_THWC = self._run_model( + condition_latent=condition_latent, + negative_prompt_embedding=negative_prompt_embedding, + raw_data_batch=raw_data_batch, + state_shape=state_shape, + ) + + # Convert to tensor, rearrange, and normalize to [0, 1] + video_0_1 = einops.rearrange(torch.from_numpy(video_np_THWC), "t h w c 
-> c t h w") / 255.0 + + # Handle overlap by skipping the first frame of subsequent segments + if len(video_output) == 0: + video_output.append(video_0_1) + else: + video_output.append(video_0_1[:, 1:, :, :]) # Skip first frame to avoid duplication + + # Concatenate all segments + video_tensor = torch.cat(video_output, dim=1) # Shape: (C, total_num_frames, H, W) + + # Convert to NumPy array for guardrail: [T, H, W, C], uint8, [0, 255] + video_np = (video_tensor.permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() # Shape: (T, H, W, C) + + if self.offload_network: + self._offload_network() + if self.offload_tokenizer: + self._offload_tokenizer() + + return video_np + + def generate( + self, + prompt: str, + image_or_video_path: str, + negative_prompt: Optional[str] = None, + ) -> tuple[np.ndarray, str] | None: + """Generate video from text prompt and optional image. + + Pipeline steps: + 1. Run safety checks on input prompt + 2. Enhance prompt using upsampler if enabled + 3. Run safety checks on upsampled prompt if applicable + 4. Convert prompt to embeddings + 5. Generate video frames using diffusion + 6. Run safety checks and apply face blur on generated video frames + + Args: + prompt: Text description of desired video + image_or_video_path: Path to conditioning image or video + negative_prompt: Optional text to guide what not to generate + + Returns: + tuple: ( + Generated video frames as uint8 np.ndarray [T, H, W, C], + Final prompt used for generation (may be enhanced) + ), or None if content fails guardrail safety checks + """ + log.info(f"Run with prompt: {prompt}") + log.info(f"Run with image or video path: {image_or_video_path}") + log.info(f"Run with negative prompt: {negative_prompt}") + log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}") + + if not self.disable_guardrail and not self.enable_prompt_upsampler: + log.info("Run guardrail on prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Input text prompt is not safe") + return None + log.info("Pass guardrail on prompt") + else: + log.info("Run prompt upsampler on image or video, input prompt is not used") + prompt = self._run_prompt_upsampler_on_prompt_with_offload(image_or_video_path=image_or_video_path) + + if not self.disable_guardrail: + log.info("Run guardrail on upsampled prompt") + is_safe = self._run_guardrail_on_prompt_with_offload(prompt) + if not is_safe: + log.critical("Upsampled text prompt is not safe") + return None + log.info("Pass guardrail on upsampled prompt") + + log.info("Run text embedding on prompt") + if negative_prompt: + prompts = [prompt, negative_prompt] + else: + prompts = [prompt] + prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts) + prompt_embedding = prompt_embeddings[0] + negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None + log.info("Finish text embedding on prompt") + + # Generate video + log.info("Run generation") + video = self._run_model_with_offload( + prompt_embedding, + negative_prompt_embedding=negative_prompt_embedding, + image_or_video_path=image_or_video_path, + frame_index_start=self.frame_index_start, + num_frame_pairs=self.num_frame_pairs, + ) + log.info("Finish generation") + + if not self.disable_guardrail: + log.info("Run guardrail on generated video") + video = self._run_guardrail_on_video_with_offload(video) + if video is None: + log.critical("Generated video is not safe") + return None + log.info("Pass guardrail on generated video") + + return 
video, prompt diff --git a/cosmos_predict1/diffusion/inference/world_interpolator.py b/cosmos_predict1/diffusion/inference/world_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fd640a52438062c85919083225254fc58a98dc --- /dev/null +++ b/cosmos_predict1/diffusion/inference/world_interpolator.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CUDA_VISIBLE_DEVICES=1 python3 -m cosmos_predict1.diffusion.inference.world_interpolator \ + --checkpoint_dir checkpoints \ + --diffusion_transformer_dir Cosmos-Predict1-7B-WorldInterpolator \ + --input_image_or_video_path assets/diffusion/interpolation_example.mp4 \ + --num_input_frames 1 \ + --offload_prompt_upsampler \ + --video_save_name diffusion-world-interpolator-7b \ + --num_video_frames 10 \ + --num_frame_pairs 2 +""" + +import argparse +import os + +import torch + +from cosmos_predict1.diffusion.inference.inference_utils import add_common_arguments, check_input_frames, validate_args +from cosmos_predict1.diffusion.inference.world_generation_pipeline import DiffusionWorldInterpolatorGenerationPipeline +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.io import read_prompts_from_file, save_video + +# from cosmos_predict1.utils.visualize.video import save_img_or_video +torch.enable_grad(False) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Video to world generation demo script") + # Add common arguments + add_common_arguments(parser) + + # Add video2world specific arguments + parser.add_argument( + "--diffusion_transformer_dir", + type=str, + default="Cosmos-Predict1-7B-WorldInterpolator", + help="DiT model weights directory name relative to checkpoint_dir", + choices=[ + "Cosmos-Predict1-7B-WorldInterpolator", + "Cosmos-Predict1-7B-WorldInterpolator_post-trained", + ], + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + parser.add_argument( + "--input_image_or_video_path", + type=str, + help="Input video/image path for generating a single video", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=2, + help="The minimum number of input frames for world_interpolator predictions.", + ) + # parser.add_argument("--num_video_frames", type=int, default=118, help="numer of video frames to sample") + parser.add_argument("--pixel_chunk_duration", type=int, default=121, help="pixel chunk duration") + parser.add_argument( + "--frame_stride", + type=int, + default=1, + help="Specifies the gap between frames used for interpolation. 
A step_size of 1 means consecutive frame " + "pairs are treated as inputs (e.g., (x0, x1), (x1, x2)), while a step_size of 2 pairs frames with one " + "frame in between (e.g., (x0, x2), (x2, x4) are treated as input at a time). Increasing this value " + "results in interpolation over a larger temporal range. Default is 1.", + ) + parser.add_argument( + "--frame_index_start", + type=int, + default=0, + help="Index of the first frame pair to process; pairs before this index are skipped. Default is 0.", + ) + parser.add_argument( + "--num_frame_pairs", + type=int, + default=None, + help="Limits the number of unique frame pairs processed for interpolation. By default (None), the interpolator " + "runs on all possible pairs extracted from the input video with the given step_size. If set to 1, only the first " + "frame pair is processed (e.g., (x0, x1) for step_size=1, (x0, x2) for step_size=2). Higher values allow processing more " + "pairs up to the maximum possible with the given step_size.", + ) + return parser.parse_args() + + +def demo(args): + """Run world-interpolator generation demo. + + This function handles the main video-to-world generation pipeline, including: + - Setting up the random seed for reproducibility + - Initializing the generation pipeline with the provided configuration + - Processing single or multiple prompts/images/videos from input + - Generating videos from prompts and images/videos + - Saving the generated videos and corresponding prompts to disk + + Args: + args (argparse.Namespace): Configuration namespace containing: + - Model configuration (checkpoint paths, model settings) + - Generation parameters (guidance, steps, dimensions) + - Input/output settings (prompts/images/videos, save paths) + - Performance options (model offloading settings) + + The function will save: + - Generated MP4 video files + - Text files containing the processed prompts + + If guardrails block the generation, a critical log message is displayed + and the function continues to the next prompt if available.
+ """ + # import ipdb; ipdb.set_trace() + misc.set_random_seed(args.seed) + inference_type = "world_interpolator" + validate_args(args, inference_type) + + if args.num_gpus > 1: + from megatron.core import parallel_state + + from cosmos_predict1.utils import distributed + + distributed.init() + parallel_state.initialize_model_parallel(context_parallel_size=args.num_gpus) + process_group = parallel_state.get_context_parallel_group() + + # Initialize video_interpolator generation model pipeline + pipeline = DiffusionWorldInterpolatorGenerationPipeline( + inference_type=inference_type, + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=args.diffusion_transformer_dir, + prompt_upsampler_dir=args.prompt_upsampler_dir, + enable_prompt_upsampler=not args.disable_prompt_upsampler, + offload_network=args.offload_diffusion_transformer, + offload_tokenizer=args.offload_tokenizer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_prompt_upsampler=args.offload_prompt_upsampler, + offload_guardrail_models=args.offload_guardrail_models, + disable_guardrail=args.disable_guardrail, + num_steps=args.num_steps, + height=args.height, + width=args.width, + fps=args.fps, + num_video_frames=args.num_video_frames, + num_input_frames=args.num_input_frames, + num_frame_pairs=args.num_frame_pairs, + frame_stride=args.frame_stride, + ) + + if args.num_gpus > 1: + pipeline.model.net.enable_context_parallel(process_group) + + # Handle multiple prompts if prompt file is provided + if args.batch_input_path: + log.info(f"Reading batch inputs from path: {args.batch_input_path}") + prompts = read_prompts_from_file(args.batch_input_path) + else: + # Single prompt case + prompts = [{"prompt": args.prompt, "visual_input": args.input_image_or_video_path}] + + os.makedirs(args.video_save_folder, exist_ok=True) + for i, input_dict in enumerate(prompts): + current_prompt = input_dict.get("prompt", None) + if current_prompt is None and args.disable_prompt_upsampler: + log.critical("Prompt is missing, skipping world generation.") + continue + current_image_or_video_path = input_dict.get("visual_input", None) + if current_image_or_video_path is None: + log.critical("Visual input is missing, skipping world generation.") + continue + + # Check input frames + if not check_input_frames(current_image_or_video_path, args.num_input_frames): + continue + + # Generate video + generated_output = pipeline.generate( + prompt=current_prompt, + image_or_video_path=current_image_or_video_path, + negative_prompt=args.negative_prompt, + ) + if generated_output is None: + log.critical("Guardrail blocked video2world generation.") + continue + video, prompt = generated_output + + # Save video + + video_save_path = os.path.join(args.video_save_folder, args.video_save_name + ".mp4") + prompt_save_path = os.path.join(args.video_save_folder, args.video_save_name + ".txt") + + save_video( + video=video, + fps=args.fps, + H=args.height, + W=args.width, + video_save_quality=5, + video_save_path=video_save_path, + ) + + with open(prompt_save_path, "w") as f: + f.write(prompt) + + log.info(f"Saved video to {video_save_path}") + log.info(f"Saved prompt to {prompt_save_path}") + + # clean up properly + if args.num_gpus > 1: + parallel_state.destroy_model_parallel() + import torch.distributed as dist + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_arguments() + demo(args) diff --git a/cosmos_predict1/diffusion/model/bu_model_world_interpolator.py 
b/cosmos_predict1/diffusion/model/bu_model_world_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8d6b31e91d8e612ad3561aebd61533a1e80e4c --- /dev/null +++ b/cosmos_predict1/diffusion/model/bu_model_world_interpolator.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.modules.res_sampler import Sampler +from cosmos_predict1.utils import log, misc + +IS_PREPROCESSED_KEY = "is_preprocessed" +from cosmos_predict1.diffusion.modules.denoiser_scaling import EDMScaling +from cosmos_predict1.diffusion.types import DenoisePrediction + + +class DiffusionWorldInterpolatorWModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + self.num_valid_latents = config.latent_shape[1] - config.num_latents_to_drop + self.input_image_key = getattr(self.config, "input_image_key", None) + self.input_data_key = self.config.input_data_key + self.sampler = Sampler() # Added to resolve the AttributeError + self.scaling = EDMScaling(self.sigma_data) + + def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: VideoExtendCondition) -> DenoisePrediction: + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + condition_dict = { + k: v.to(self.precision) if isinstance(v, torch.Tensor) else v for k, v in condition.to_dict().items() + } + net_output = self.net( + x=batch_mul(c_in, xt), + timesteps=c_noise, + **condition_dict, + ) + logvar = self.model.logvar(c_noise) if hasattr(self.model, "logvar") else None + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + return DenoisePrediction(x0_pred, eps_pred, logvar) + + def _normalize_video_databatch_inplace(self, data_batch: Dict[str, Tensor]) -> None: + if self.input_data_key in data_batch: + if IS_PREPROCESSED_KEY not in data_batch or not data_batch[IS_PREPROCESSED_KEY]: + assert data_batch[self.input_data_key].dtype == torch.uint8, "Video data must be uint8." 
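+ # The next line maps uint8 pixels in [0, 255] to [-1, 1]: x / 127.5 - 1 sends 0 -> -1.0 and 255 -> 1.0.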
+ data_batch[self.input_data_key] = data_batch[self.input_data_key].to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[IS_PREPROCESSED_KEY] = True + + def _augment_image_dim_inplace(self, data_batch: Dict[str, Tensor]) -> None: + if self.input_image_key in data_batch: + if IS_PREPROCESSED_KEY not in data_batch or not data_batch[IS_PREPROCESSED_KEY]: + data_batch[self.input_image_key] = rearrange( + data_batch[self.input_image_key], "b c h w -> b c 1 h w" + ).contiguous() + data_batch[IS_PREPROCESSED_KEY] = True + + def is_image_batch(self, data_batch: Dict[str, Tensor]) -> bool: + is_image = self.input_image_key in data_batch + is_video = self.input_data_key in data_batch + assert is_image != is_video, "Batch must contain either image or video data, not both or neither." + return is_image + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Optional[int] = None + ) -> VideoExtendCondition: + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.debug( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + B, C, T, H, W = latent_state.shape + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + if condition.video_cond_bool: + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: + condition.condition_video_input_mask = zeros_padding + return condition + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + add_input_frames_guidance: bool = False, + ): + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + assert condition.gt_latent.allclose(uncondition.gt_latent) + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + return condition, uncondition + + def _augment_noise_with_latent( + self, + xt: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_augment_sigma: float = 0.001, + seed: int = 1, + ) -> tuple[Tensor, Tensor, Tensor]: + augment_sigma = 
condition_augment_sigma + latent = condition.gt_latent + indicator = condition.condition_video_indicator + if augment_sigma >= sigma: + indicator = torch.zeros_like(indicator) + noise = misc.arch_invariant_rand(latent.shape, torch.float32, self.tensor_kwargs["device"], seed) + augment_latent = latent + noise * augment_sigma + augment_latent = self.scheduler.precondition_inputs(augment_latent, augment_sigma) + augment_latent_unscaled = self._reverse_precondition_input(augment_latent, sigma) + if self.net.is_context_parallel_enabled: + latent = split_inputs_cp(condition.gt_latent, seq_dim=2, cp_group=self.net.cp_group) + indicator = split_inputs_cp(indicator, seq_dim=2, cp_group=self.net.cp_group) + augment_latent_unscaled = split_inputs_cp(augment_latent_unscaled, seq_dim=2, cp_group=self.net.cp_group) + new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt + return new_xt, latent, indicator + + def _reverse_precondition_input(self, xt: Tensor, sigma: Tensor) -> Tensor: + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + xt_unscaled = xt / c_in + return xt_unscaled + + def _reverse_precondition_output(self, latent: Tensor, xt: Tensor, sigma: Tensor) -> Tensor: + sigma_data = self.scheduler.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + latent_unscaled = latent / c_out - c_skip * xt + return latent_unscaled + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed_inference: int = 1, + ) -> Callable: + assert condition_latent is not None, "condition_latent must be provided for video generation." 
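+ # The x0_fn returned below denoises a conditional and an unconditional branch and combines them
+ # with classifier-free guidance: cond_x0 + guidance * (cond_x0 - uncond_x0).
+ # Conditioning frames enter via _augment_noise_with_latent, which lightly noises the ground-truth
+ # latent (condition_augment_sigma, default 0.001) and overwrites the noisy input wherever
+ # condition_video_indicator is set.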
+ condition, uncondition = self._get_conditions( + data_batch, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + add_input_frames_guidance=add_input_frames_guidance, + ) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_xt, cond_latent, cond_indicator = self._augment_noise_with_latent( + noise_x, + sigma, + condition, + condition_augment_sigma=condition_video_augment_sigma_in_inference or 0.001, + seed=seed_inference, + ) + cond_pred = self.denoise(cond_xt, sigma, condition) + cond_x0 = cond_pred.x0_pred_replaced if hasattr(cond_pred, "x0_pred_replaced") else cond_pred.x0 + uncond_xt, _, _ = self._augment_noise_with_latent( + noise_x, + sigma, + uncondition, + condition_augment_sigma=condition_video_augment_sigma_in_inference or 0.001, + seed=seed_inference, + ) + uncond_pred = self.denoise(uncond_xt, sigma, uncondition) + uncond_x0 = uncond_pred.x0_pred_replaced if hasattr(uncond_pred, "x0_pred_replaced") else uncond_pred.x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + self._normalize_video_databatch_inplace(data_batch) + # self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. 
{self.state_shape}") + state_shape = self.state_shape + assert condition_latent is not None, "condition_latent should be provided" + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed_inference=seed, + ) + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed + ) + * 80 + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=80) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / 80 + return samples diff --git a/cosmos_predict1/diffusion/model/model_gen3c.py b/cosmos_predict1/diffusion/model/model_gen3c.py new file mode 100644 index 0000000000000000000000000000000000000000..2c77228d50013b89559f1f7bd3a831101861b5fc --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_gen3c.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
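# The generate_samples_from_batch implementations in this diff all start from pure Gaussian noise
# scaled to sigma_max = 80 and let a sampler follow the model's x0 prediction down to a small sigma.
# A minimal, standalone sketch of that pattern is given below; toy_x0_fn, euler_sample and the
# latent shape are hypothetical stand-ins, not the repository's API.
import torch


def toy_x0_fn(noise_x: torch.Tensor, sigma: float) -> torch.Tensor:
    # Stand-in denoiser: a real model predicts the clean latent from the noisy one at noise level sigma.
    return noise_x * (1.0 / (1.0 + sigma**2)) ** 0.5


def euler_sample(x0_fn, state_shape=(1, 16, 8, 11, 20), sigma_max=80.0, sigma_min=0.002, num_steps=35, seed=0):
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(state_shape, generator=g) * sigma_max  # x ~ N(0, sigma_max^2 I)
    sigmas = torch.linspace(sigma_max, sigma_min, num_steps + 1)
    for i in range(num_steps):
        sigma, sigma_next = sigmas[i].item(), sigmas[i + 1].item()
        d = (x - x0_fn(x, sigma)) / sigma  # probability-flow ODE derivative dx/dsigma
        x = x + d * (sigma_next - sigma)  # Euler step toward lower noise
    return x


latents = euler_sample(toy_x0_fn)  # same rank as the video latents above: (B, C, T, H, W)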
+ +from typing import Optional + +import torch +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition + + +class DiffusionGen3CModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.frame_buffer_max = config.frame_buffer_max + self.chunk_size = 121 + + def encode_warped_frames( + self, + condition_state: torch.Tensor, + condition_state_mask: torch.Tensor, + dtype: torch.dtype, + ): + + assert condition_state.dim() == 6 + condition_state_mask = (condition_state_mask * 2 - 1).repeat(1, 1, 1, 3, 1, 1) + latent_condition = [] + for i in range(condition_state.shape[2]): + current_video_latent = self.encode( + condition_state[:, :, i].permute(0, 2, 1, 3, 4).to(dtype) + ).contiguous() # 1, 16, 8, 88, 160 + + current_mask_latent = self.encode( + condition_state_mask[:, :, i].permute(0, 2, 1, 3, 4).to(dtype) + ).contiguous() + latent_condition.append(current_video_latent) + latent_condition.append(current_mask_latent) + for _ in range(self.frame_buffer_max - condition_state.shape[2]): + latent_condition.append(torch.zeros_like(current_video_latent)) + latent_condition.append(torch.zeros_like(current_mask_latent)) + + latent_condition = torch.cat(latent_condition, dim=1) + return latent_condition + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + add_input_frames_guidance: bool = False, + ): + """Get the conditions for the model. + + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + # encode warped frames + condition_state, condition_state_mask = ( + data_batch["condition_state"], + data_batch["condition_state_mask"], + ) + latent_condition = self.encode_warped_frames( + condition_state, condition_state_mask, self.tensor_kwargs["dtype"] + ) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + condition = self.add_condition_pose(latent_condition, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + uncondition = self.add_condition_pose(latent_condition, uncondition, drop_out_latent = True) + assert condition.gt_latent.allclose(uncondition.gt_latent) + + # For inference, check if parallel_state is initialized + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + return condition, uncondition + + def add_condition_pose(self, latent_condition: torch.Tensor, condition: 
VideoExtendCondition, + drop_out_latent: bool = False) -> VideoExtendCondition: + """Add pose condition to the condition object. For camera control model + Args: + data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + if drop_out_latent: + condition.condition_video_pose = torch.zeros_like(latent_condition.contiguous()) + else: + condition.condition_video_pose = latent_condition.contiguous() + + to_cp = self.net.is_context_parallel_enabled + + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition diff --git a/cosmos_predict1/diffusion/model/model_t2w.py b/cosmos_predict1/diffusion/model/model_t2w.py new file mode 100644 index 0000000000000000000000000000000000000000..b2910e09ec3e9779f6b962251a3bba8c70dddbb8 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_t2w.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from diffusers import EDMEulerScheduler +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import BaseVideoCondition +from cosmos_predict1.diffusion.module import parallel +from cosmos_predict1.diffusion.module.blocks import FourierFeatures +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE +from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser +from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.distributed import get_rank +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + + +class DiffusionT2WModel(torch.nn.Module): + """Text-to-world diffusion model that generates video frames from text descriptions. + + This model implements a diffusion-based approach for generating videos conditioned on text input. + It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling, + and classifier-free guidance. + """ + + def __init__(self, config): + """Initialize the diffusion model. 
+ + Args: + config: Configuration object containing model parameters and architecture settings + """ + super().__init__() + # Initialize trained_data_record with defaultdict, key: image, video, iteration + self.config = config + + self.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} + log.debug(f"DiffusionModel: precision {self.precision}") + # Timer passed to network to detect slow ranks. + # 1. set data keys and data information + self.sigma_data = config.sigma_data + self.state_shape = list(config.latent_shape) + self.setup_data_key() + + # 2. setup up diffusion processing and scaling~(pre-condition), sampler + self.scheduler = EDMEulerScheduler(sigma_max=80, sigma_min=0.0002, sigma_data=self.sigma_data) + self.tokenizer = None + self.model = None + + @property + def net(self): + return self.model.net + + @property + def conditioner(self): + return self.model.conditioner + + @property + def logvar(self): + return self.model.logvar + + def set_up_tokenizer(self, tokenizer_dir: str): + self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer) + self.tokenizer.load_weights(tokenizer_dir) + if hasattr(self.tokenizer, "reset_dtype"): + self.tokenizer.reset_dtype() + + @misc.timer("DiffusionModel: set_up_model") + def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format): + """Initialize the core model components including network, conditioner and logvar.""" + self.model = self.build_model() + if self.config.peft_control and self.config.peft_control.enabled: + log.info("Setting up LoRA layers") + peft_control_config_parser = LayerControlConfigParser(config=self.config.peft_control) + peft_control_config = peft_control_config_parser.parse() + add_lora_layers(self.model, peft_control_config) + num_lora_params = setup_lora_requires_grad(self.model) + self.model.requires_grad_(False) + if num_lora_params == 0: + raise ValueError("No LoRA parameters found. Please check the model configuration.") + self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) + + def build_model(self) -> torch.nn.ModuleDict: + """Construct the model's neural network components. + + Returns: + ModuleDict containing the network, conditioner and logvar components + """ + config = self.config + net = lazy_instantiate(config.net) + conditioner = lazy_instantiate(config.conditioner) + logvar = torch.nn.Sequential( + FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) + ) + + return torch.nn.ModuleDict( + { + "net": net, + "conditioner": conditioner, + "logvar": logvar, + } + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """Encode input state into latent representation using VAE. + + Args: + state: Input tensor to encode + + Returns: + Encoded latent representation scaled by sigma_data + """ + return self.tokenizer.encode(state) * self.sigma_data + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """Decode latent representation back to pixel space using VAE. 
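+ The latent is divided by sigma_data before decoding, undoing the scaling applied in encode().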
+ + Args: + latent: Latent tensor to decode + + Returns: + Decoded tensor in pixel space + """ + return self.tokenizer.decode(latent / self.sigma_data) + + def setup_data_key(self) -> None: + """Configure input data keys for video and image data.""" + self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + ) -> Tensor: + """Generate samples from a data batch using diffusion sampling. + + This function generates samples from either image or video data batches using diffusion sampling. + It handles both conditional and unconditional generation with classifier-free guidance. + + Args: + data_batch (dict): Raw data batch from the training data loader + guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5. + seed (int, optional): Random seed for reproducibility. Defaults to 1. + state_shape (tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None. + n_sample (int | None, optional): Number of samples to generate. Defaults to 1. + is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False. + num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35. + + Returns: + Tensor: Generated samples after diffusion sampling + """ + condition, uncondition = self._get_conditions(data_batch, is_negative_prompt) + + self.scheduler.set_timesteps(num_steps) + + xt = torch.randn(size=(n_sample,) + tuple(state_shape)) * self.scheduler.init_noise_sigma + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + + for t in self.scheduler.timesteps: + xt = xt.to(**self.tensor_kwargs) + xt_scaled = self.scheduler.scale_model_input(xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(net_output, t, xt).prev_sample + samples = xt + + if to_cp: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + ): + """Get the conditions for the model. 
+ + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + return condition, uncondition + + +def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: + condition_kwargs = {} + for k, v in condition.to_dict().items(): + if isinstance(v, torch.Tensor): + assert not v.requires_grad, f"{k} requires gradient. the current impl does not support it" + condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp) + condition = type(condition)(**condition_kwargs) + return condition diff --git a/cosmos_predict1/diffusion/model/model_t2w_multiview.py b/cosmos_predict1/diffusion/model/model_t2w_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..8008cac54bd74d4ef835fca63218dffd5fc3fe31 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_t2w_multiview.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
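# The sampling loops above apply classifier-free guidance by extrapolating from the unconditional
# network output toward the conditional one: out = cond + g * (cond - uncond) = (1 + g) * cond - g * uncond.
# A small self-contained illustration follows; cfg_combine is a hypothetical helper, not part of the codebase.
import torch


def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, guidance: float) -> torch.Tensor:
    # guidance = 0 recovers the purely conditional prediction; larger values push the
    # result further away from the unconditional branch.
    return cond + guidance * (cond - uncond)


cond = torch.tensor([1.0, 2.0])
uncond = torch.tensor([0.5, 1.0])
print(cfg_combine(cond, uncond, guidance=7.0))  # tensor([4.5000, 9.0000])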
+ +from typing import Optional, Union + +import torch +from einops import rearrange +from torch import Tensor + +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils import log, misc + + +class DiffusionMultiviewT2WModel(DiffusionT2WModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.net.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.tokenizer.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.tokenizer.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + ) -> Tensor: + """Generate samples from a data batch using diffusion sampling. + + This function generates samples from either image or video data batches using diffusion sampling. + It handles both conditional and unconditional generation with classifier-free guidance. + + Args: + data_batch (dict): Raw data batch from the training data loader + guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5. + seed (int, optional): Random seed for reproducibility. Defaults to 1. + state_shape (tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None. + n_sample (int | None, optional): Number of samples to generate. Defaults to 1. + is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False. + num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35. 
+ + Returns: + Tensor: Generated samples after diffusion sampling + """ + condition, uncondition = self._get_conditions(data_batch, is_negative_prompt) + + self.scheduler.set_timesteps(num_steps) + + xt = torch.randn(size=(n_sample,) + tuple(state_shape)) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = rearrange(xt, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + xt = rearrange(xt, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + for t in self.scheduler.timesteps: + xt = xt.to(**self.tensor_kwargs) + xt_scaled = self.scheduler.scale_model_input(xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(net_output, t, xt).prev_sample + samples = xt + + if to_cp: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + return samples diff --git a/cosmos_predict1/diffusion/model/model_v2w.py b/cosmos_predict1/diffusion/model/model_v2w.py new file mode 100644 index 0000000000000000000000000000000000000000..246485e94c8f4400149469fbce8b0d2965de1271 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_v2w.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils import log, misc + + +class DiffusionV2WModel(DiffusionT2WModel): + def __init__(self, config): + super().__init__(config) + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Optional[int] = None + ) -> VideoExtendCondition: + """Adds conditioning masks to VideoExtendCondition object. + + Creates binary indicators and input masks for conditional video generation. 
+ + Args: + latent_state: Input latent tensor (B,C,T,H,W) + condition: VideoExtendCondition object to update + num_condition_t: Number of frames to condition on + + Returns: + Updated VideoExtendCondition with added masks: + - condition_video_indicator: Binary tensor marking condition regions + - condition_video_input_mask: Input mask for network + - gt_latent: Ground truth latent tensor + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.debug( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + return condition + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = 1, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + condition_augment_sigma: float = None, + add_input_frames_guidance: bool = False, + ) -> Tensor: + """Generates video samples conditioned on input frames. 
+ + Args: + data_batch: Input data dictionary + guidance: Classifier-free guidance scale + seed: Random seed for reproducibility + state_shape: Shape of output tensor (defaults to model's state shape) + n_sample: Number of samples to generate (defaults to batch size) + is_negative_prompt: Whether to use negative prompting + num_steps: Number of denoising steps + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + condition_augment_sigma: Noise level for condition augmentation + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + Generated video samples tensor + """ + assert condition_latent is not None, "condition_latent should be provided" + condition, uncondition = self._get_conditions( + data_batch, is_negative_prompt, condition_latent, num_condition_t, add_input_frames_guidance + ) + + self.scheduler.set_timesteps(num_steps) + if n_sample is None: + n_sample = condition_latent.shape[0] + xt = torch.randn(size=(n_sample,) + tuple(state_shape), **self.tensor_kwargs) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + + for t in self.scheduler.timesteps: + self.scheduler._init_step_index(t) + sigma = self.scheduler.sigmas[self.scheduler.step_index].to(**self.tensor_kwargs) + # Form new noise from latent + xt = xt.to(**self.tensor_kwargs) + new_xt, latent, indicator = self._augment_noise_with_latent( + xt, sigma, condition, condition_augment_sigma=condition_augment_sigma, seed=seed + ) + new_xt = new_xt.to(**self.tensor_kwargs) + new_xt_scaled = self.scheduler.scale_model_input(new_xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=new_xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=new_xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Replace indicated output with latent + latent_unscaled = self._reverse_precondition_output(latent, xt=new_xt, sigma=sigma) + new_output = indicator * latent_unscaled + (1 - indicator) * net_output + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(new_output, t, new_xt).prev_sample + samples = xt + + if to_cp: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + return samples + + def _get_conditions( + self, + data_batch: dict, + is_negative_prompt: bool = False, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + add_input_frames_guidance: bool = False, + ): + """Get the conditions for the model. 
+ + Args: + data_batch: Input data dictionary + is_negative_prompt: Whether to use negative prompting + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + condition: Input conditions + uncondition: Conditions removed/reduced to minimum (unconditioned) + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + assert condition.gt_latent.allclose(uncondition.gt_latent) + + # For inference, check if parallel_state is initialized + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) + + return condition, uncondition + + def _augment_noise_with_latent( + self, + xt: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_augment_sigma: float = 0.001, + seed: int = 1, + ) -> tuple[Tensor, Tensor, Tensor]: + """Augments the conditional frames with noise during inference. + + Args: + xt (Tensor): noise + sigma (Tensor): noise level for the generation region + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. 
+ condition_augment_sigma (float): sigma for condition video augmentation in inference + seed (int): random seed for reproducibility + Returns: + new_xt (Tensor): new latent-augmented noise tensor in shape B,C,T,H,W + latent (Tensor): ground-truth latent tensor in shape B,C,T,H,W + indicator (Tensor): ground-truth latent binary indicator tensor in shape B,C,T,H,W + + """ + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + augment_sigma = condition_augment_sigma + latent = condition.gt_latent + indicator = condition.condition_video_indicator + if augment_sigma >= sigma: + indicator = torch.zeros_like(indicator) + # Now apply the augment_sigma to the gt_latent + noise = misc.arch_invariant_rand( + latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + augment_latent = latent + noise * augment_sigma + augment_latent = self.scheduler.precondition_inputs(augment_latent, augment_sigma) + augment_latent_unscaled = self._reverse_precondition_input(augment_latent, sigma) + if self.net.is_context_parallel_enabled: + latent = split_inputs_cp(condition.gt_latent, seq_dim=2, cp_group=self.net.cp_group) + indicator = split_inputs_cp(indicator, seq_dim=2, cp_group=self.net.cp_group) + augment_latent_unscaled = split_inputs_cp(augment_latent_unscaled, seq_dim=2, cp_group=self.net.cp_group) + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt + return new_xt, latent, indicator + + def _reverse_precondition_input(self, xt: Tensor, sigma: Tensor) -> Tensor: + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + xt_unscaled = xt / c_in + return xt_unscaled + + def _reverse_precondition_output(self, latent: Tensor, xt: Tensor, sigma: Tensor) -> Tensor: + sigma_data = self.scheduler.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + latent_unscaled = (latent - c_skip * xt) / c_out + return latent_unscaled diff --git a/cosmos_predict1/diffusion/model/model_v2w_multiview.py b/cosmos_predict1/diffusion/model/model_v2w_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..ece347fd51e50163b9d0d8c8eb7cbf3a57cce58f --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_v2w_multiview.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
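
The video-to-world models pin the conditioning frames during sampling: a 1x1xTx1x1 indicator marks the conditioned latent frames, the ground-truth latent is lightly re-noised at `condition_augment_sigma`, rescaled from the EDM input preconditioning of that small sigma to the preconditioning of the current sigma, and blended with the generation-region noise; on the output side, `_reverse_precondition_output` maps the same latent into the network's output space so it can overwrite the prediction on those frames. The sketch below reproduces that arithmetic on toy tensors; `edm_coeffs`, the tensor shapes, and `SIGMA_DATA = 0.5` are assumptions for illustration (the model reads sigma_data from its scheduler config), not the repository's API.

```python
import torch

SIGMA_DATA = 0.5  # assumed for illustration; taken from the scheduler config in the model

def edm_coeffs(sigma: torch.Tensor):
    # Standard EDM preconditioning coefficients, matching _reverse_precondition_*.
    c_skip = SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
    c_out = sigma * SIGMA_DATA / (sigma**2 + SIGMA_DATA**2) ** 0.5
    c_in = 1.0 / (sigma**2 + SIGMA_DATA**2) ** 0.5
    return c_skip, c_out, c_in

B, C, T, H, W = 1, 16, 8, 4, 4
xt = torch.randn(B, C, T, H, W)          # generation-region noise at level `sigma`
gt_latent = torch.randn(B, C, T, H, W)   # encoded conditioning frames
num_condition_t = 2

# Binary indicator over latent time: 1 on the first num_condition_t frames.
indicator = torch.zeros(1, 1, T, 1, 1)
indicator[:, :, :num_condition_t] = 1.0

sigma = torch.tensor(10.0)               # current noise level in the sampling loop
augment_sigma = torch.tensor(0.001)      # condition_augment_sigma

# Input side (_augment_noise_with_latent): noise the latent a little, scale it for
# augment_sigma, then undo the scaling that scale_model_input will apply at sigma.
augment_latent = gt_latent + torch.randn_like(gt_latent) * augment_sigma
_, _, c_in_aug = edm_coeffs(augment_sigma)
_, _, c_in_cur = edm_coeffs(sigma)
augment_latent_unscaled = augment_latent * c_in_aug / c_in_cur
new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt

# Output side (_reverse_precondition_output): express the latent as a raw network
# output so it can replace the prediction on the conditioned frames.
c_skip, c_out, _ = edm_coeffs(sigma)
latent_as_net_output = (gt_latent - c_skip * new_xt) / c_out
print(new_xt.shape, latent_as_net_output.shape)
```
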
+ +from typing import Optional, Union + +import torch +from einops import rearrange +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.utils import log, misc + + +class DiffusionMultiviewV2WModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.net.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.tokenizer.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.tokenizer.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional 
region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + return condition + + def generate_samples_from_batch( + self, + data_batch: dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Optional[torch.Tensor] = None, + num_condition_t: Optional[int] = None, + condition_augment_sigma: float = None, + add_input_frames_guidance: bool = False, + ) -> Tensor: + """Generates video samples conditioned on input frames. + + Args: + data_batch: Input data dictionary + guidance: Classifier-free guidance scale + seed: Random seed for reproducibility + state_shape: Shape of output tensor (defaults to model's state shape) + n_sample: Number of samples to generate (defaults to batch size) + is_negative_prompt: Whether to use negative prompting + num_steps: Number of denoising steps + condition_latent: Conditioning frames tensor (B,C,T,H,W) + num_condition_t: Number of frames to condition on + condition_augment_sigma: Noise level for condition augmentation + add_input_frames_guidance: Whether to apply guidance to input frames + + Returns: + Generated video samples tensor + """ + assert condition_latent is not None, "condition_latent should be provided" + condition, uncondition = self._get_conditions( + data_batch, is_negative_prompt, condition_latent, num_condition_t, add_input_frames_guidance + ) + + self.scheduler.set_timesteps(num_steps) + if n_sample is None: + n_sample = condition_latent.shape[0] + xt = torch.randn(size=(n_sample,) + tuple(state_shape), **self.tensor_kwargs) * self.scheduler.init_noise_sigma + + to_cp = self.net.is_context_parallel_enabled + if to_cp: + xt = rearrange(xt, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group) + xt = rearrange(xt, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + for t in self.scheduler.timesteps: + self.scheduler._init_step_index(t) + sigma = self.scheduler.sigmas[self.scheduler.step_index].to(**self.tensor_kwargs) + # Form new noise from latent + new_xt, latent, indicator = self._augment_noise_with_latent( + xt, sigma, condition, condition_augment_sigma=condition_augment_sigma, seed=seed + ) + new_xt = new_xt.to(**self.tensor_kwargs) + new_xt_scaled = self.scheduler.scale_model_input(new_xt, timestep=t) + # Predict the noise residual + t = t.to(**self.tensor_kwargs) + net_output_cond = self.net(x=new_xt_scaled, timesteps=t, **condition.to_dict()) + net_output_uncond = self.net(x=new_xt_scaled, timesteps=t, **uncondition.to_dict()) + net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond) + # Replace indicated output with latent + latent_unscaled = self._reverse_precondition_output(latent, xt=new_xt, sigma=sigma) + new_output = indicator * latent_unscaled + (1 - indicator) * net_output + # Compute the previous noisy sample x_t -> x_t-1 + xt = self.scheduler.step(new_output, t, new_xt).prev_sample + samples = xt + + if to_cp: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + 
return samples + + def _augment_noise_with_latent( + self, + xt: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_augment_sigma: float = 0.001, + seed: int = 1, + ) -> tuple[Tensor, Tensor, Tensor]: + """Augments the conditional frames with noise during inference. + + Args: + xt (Tensor): noise + sigma (Tensor): noise level for the generation region + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. + condition_augment_sigma (float): sigma for condition video augmentation in inference + seed (int): random seed for reproducibility + Returns: + new_xt (Tensor): new latent-augmented noise tensor in shape B,C,T,H,W + latent (Tensor): ground-truth latent tensor in shape B,C,T,H,W + indicator (Tensor): ground-truth latent binary indicator tensor in shape B,C,T,H,W + + """ + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + augment_sigma = condition_augment_sigma + latent = condition.gt_latent + indicator = condition.condition_video_indicator + if augment_sigma >= sigma: + indicator = torch.zeros_like(indicator) + # Now apply the augment_sigma to the gt_latent + noise = misc.arch_invariant_rand( + latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + augment_latent = latent + noise * augment_sigma + augment_latent = self.scheduler.precondition_inputs(augment_latent, augment_sigma) + augment_latent_unscaled = self._reverse_precondition_input(augment_latent, sigma) + if self.net.is_context_parallel_enabled: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + indicator = rearrange(indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + augment_latent_unscaled = rearrange( + augment_latent_unscaled, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + + latent = split_inputs_cp(latent, seq_dim=2, cp_group=self.net.cp_group) + indicator = split_inputs_cp(indicator, seq_dim=2, cp_group=self.net.cp_group) + augment_latent_unscaled = split_inputs_cp(augment_latent_unscaled, seq_dim=2, cp_group=self.net.cp_group) + + latent = rearrange(latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + indicator = rearrange(indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + augment_latent_unscaled = rearrange( + augment_latent_unscaled, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_xt = indicator * augment_latent_unscaled + (1 - indicator) * xt + return new_xt, latent, indicator diff --git a/cosmos_predict1/diffusion/model/model_world_interpolator.py b/cosmos_predict1/diffusion/model/model_world_interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..c82c7ab8c5e975e644fab120f749d30de93ff370 --- /dev/null +++ b/cosmos_predict1/diffusion/model/model_world_interpolator.py @@ -0,0 +1,623 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from statistics import NormalDist +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition +from cosmos_predict1.diffusion.config.base.conditioner import VideoCondBoolConfig +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.model.model_v2w import DiffusionV2WModel, broadcast_condition +from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.modules.res_sampler import Sampler +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.models.model import _broadcast +from cosmos_predict1.utils import log, misc + +IS_PREPROCESSED_KEY = "is_preprocessed" +from dataclasses import dataclass, fields + +from cosmos_predict1.diffusion.modules.denoiser_scaling import EDMScaling +from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE +from cosmos_predict1.diffusion.types import DenoisePrediction + + +@dataclass +class VideoDenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction + net_in: Optional[torch.Tensor] = None # input to the network + net_x0_pred: Optional[torch.Tensor] = None # prediction of x0 from the network + xt: Optional[torch.Tensor] = None # input to the network, before multiply with c_in + x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +class DiffusionWorldInterpolatorWModel(DiffusionV2WModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + self.num_valid_latents = config.latent_shape[1] - config.num_latents_to_drop + self.setup_data_key() # Initialize input_data_key and input_image_key + self.sampler = Sampler() + self.scaling = EDMScaling(self.sigma_data) + self.sde = EDMSDE( + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + ) + + def setup_data_key(self) -> None: + """Initialize data keys for image and video inputs.""" + self.input_data_key = self.config.input_data_key + self.input_image_key = self.config.input_image_key + + def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: + """Determine if the data batch is an image batch or a video batch. + + Args: + data_batch (dict[str, Tensor]): Input data batch. + + Returns: + bool: True if the batch is an image batch, False if it is a video batch. 
+ + Raises: + AssertionError: If both or neither of input_image_key and input_data_key are present. + """ + is_image = self.input_image_key in data_batch + is_video = self.input_data_key in data_batch + assert ( + is_image != is_video + ), "Only one of the input_image_key or input_data_key should be present in the data_batch." + return is_image + + def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + """Normalizes video data in-place on a CUDA device to reduce data loading overhead. + + Args: + data_batch (dict[str, Tensor]): Dictionary containing the video data. + input_key (str, optional): Key for the video data in the batch. Defaults to self.input_data_key. + + Side Effects: + Modifies the video data tensor in-place to scale from [0, 255] to [-1, 1]. + """ + input_key = self.input_data_key if input_key is None else input_key + if input_key in data_batch: + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." + assert torch.all( + (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) + ), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" + else: + assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." + data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[IS_PREPROCESSED_KEY] = True + + def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + """Augments image data in-place by adding a temporal dimension. + + Args: + data_batch (dict[str, Tensor]): Dictionary containing the image data. + input_key (str, optional): Key for the image data in the batch. Defaults to self.input_image_key. + + Side Effects: + Modifies the image data tensor in-place to add a temporal dimension (B,C,H,W -> B,C,1,H,W). 
+ """ + input_key = self.input_image_key if input_key is None else input_key + if input_key in data_batch: + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert ( + data_batch[input_key].shape[2] == 1 + ), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" + return + else: + data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() + data_batch[IS_PREPROCESSED_KEY] = True + + def normalize_condition_latent(self, condition_latent: torch.Tensor) -> torch.Tensor: + """Normalize the condition latent tensor to have zero mean and unit variance.""" + condition_latent_2D = rearrange(condition_latent, "b c t h w -> b c t (h w)") + mean = condition_latent_2D.mean(dim=-1) + std = condition_latent_2D.std(dim=-1) + mean = mean.unsqueeze(-1).unsqueeze(-1) + std = std.unsqueeze(-1).unsqueeze(-1) + condition_latent = (condition_latent - mean) / std + return condition_latent + + def draw_augment_sigma_and_epsilon( + self, size: int, condition: VideoExtendCondition, p_mean: float, p_std: float, multiplier: float + ) -> Tuple[Tensor, Tensor]: + """Draw sigma and epsilon for augmenting conditional latent frames.""" + is_video_batch = condition.data_type == DataType.VIDEO + del condition + batch_size = size[0] + epsilon = torch.randn(size, **self.tensor_kwargs) + + gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") + sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) + + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + def augment_conditional_latent_frames( + self, + condition: VideoExtendCondition, + cfg_video_cond_bool: VideoCondBoolConfig, + gt_latent: Tensor, + condition_video_augment_sigma_in_inference: float = 0.001, + sigma: Tensor = None, + seed_inference: int = 1, + ) -> Union[VideoExtendCondition, Tensor]: + """Augment the condition input with noise.""" + if cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma": + augment_sigma, _ = self.draw_augment_sigma_and_epsilon( + gt_latent.shape, + condition, + cfg_video_cond_bool.augment_sigma_sample_p_mean, + cfg_video_cond_bool.augment_sigma_sample_p_std, + cfg_video_cond_bool.augment_sigma_sample_multiplier, + ) + noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) + elif cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma_fixed": + log.debug( + f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}, sigma={sigma.flatten()[0]}" + ) + assert ( + condition_video_augment_sigma_in_inference is not None + ), "condition_video_augment_sigma_in_inference should be provided" + augment_sigma = condition_video_augment_sigma_in_inference + + if augment_sigma >= sigma.flatten()[0]: + log.debug("augment_sigma larger than sigma or other frame, remove condition") + condition.condition_video_indicator = condition_video_indicator * 0 + + augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) + noise = misc.arch_invariant_rand( + gt_latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed_inference, + ) + else: + raise ValueError(f"does not support {cfg_video_cond_bool.apply_corruption_to_condition_region}") + + 
augment_latent = gt_latent + noise * augment_sigma.view(-1, 1, 1, 1, 1) + _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) + + if cfg_video_cond_bool.condition_on_augment_sigma: + if condition.condition_video_indicator.sum() > 0: + condition.condition_video_augment_sigma = c_noise_augment + else: + condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) + + augment_latent_cin = batch_mul(augment_latent, c_in_augment) + _, _, c_in, _ = self.scaling(sigma=sigma) + augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) + + return condition, augment_latent_cin + + def super_denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + + if getattr(self.config, "use_dummy_temporal_dim", False): + # When using video DiT model for image, we need to use a dummy temporal dimension. + xt = xt.unsqueeze(2) + + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + # get precondition for the network + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + + # forward pass through the network + net_output = self.net( + x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + **condition.to_dict(), + ) + + logvar = self.model.logvar(c_noise) + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + + # get noise prediction based on sde + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + + if getattr(self.config, "use_dummy_temporal_dim", False): + x0_pred = x0_pred.squeeze(2) + eps_pred = eps_pred.squeeze(2) + + return DenoisePrediction(x0_pred, eps_pred, logvar) + + def drop_out_condition_region( + self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig + ) -> Tensor: + """Drop out the conditional region for CFG on input frames.""" + if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": + augment_latent_drop = torch.zeros_like(augment_latent) + elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": + augment_latent_drop = noise_x + else: + raise NotImplementedError( + f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" + ) + return augment_latent_drop + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + seed_inference: int = 1, + ) -> VideoDenoisePrediction: + """Denoise the noisy input tensor for video data.""" + assert ( + condition.gt_latent is not None + ), "find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = self.normalize_condition_latent(condition_latent) + + condition, augment_latent = self.augment_conditional_latent_frames( + 
condition, + cfg_video_cond_bool, + condition_latent, + condition_video_augment_sigma_in_inference, + sigma, + seed_inference=seed_inference, + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + if not condition.video_cond_bool: + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + denoise_pred = self.super_denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Supports condition latent for video generation. + + Args: + data_batch (Dict): Input data batch. + guidance (float): Guidance scale for classifier-free guidance. + seed (int): Random seed for reproducibility. + state_shape (Tuple | None): Shape of the latent state, defaults to self.state_shape if None. + n_sample (int | None): Number of samples to generate, inferred from batch if None. + is_negative_prompt (bool): Use negative prompt for unconditioned generation. + num_steps (int): Number of sampling steps. + condition_latent (torch.Tensor | None): Latent tensor (B,C,T,H,W) as condition for video generation. + num_condition_t (int | None): Number of condition frames in T dimension. + condition_video_augment_sigma_in_inference (float): Sigma for augmenting condition video in inference. + add_input_frames_guidance (bool): Apply guidance to input frames for CFG. + return_noise (bool): Return initial noise along with samples. + + Returns: + Tensor | Tuple[Tensor, Tensor]: Generated samples, or (samples, noise) if return_noise is True. 
+ """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed_inference=seed, + ) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` for denoising based on the data batch and condition latent. + + Args: + data_batch (Dict): Input data batch. + guidance (float): Guidance scale. + is_negative_prompt (bool): Use negative prompt for unconditioned generation. + condition_latent (torch.Tensor): Latent tensor (B,C,T,H,W) as condition. + num_condition_t (int | None): Number of condition frames. + condition_video_augment_sigma_in_inference (float): Sigma for condition augmentation. + add_input_frames_guidance (bool): Apply guidance to input frames. + seed_inference (int): Seed for inference noise. + + Returns: + Callable: Function `x0_fn(noise_x, sigma)` returning denoised prediction. 
+ """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + # Should be used for both training and inference. The first and last frame will be condition frames. 
+ assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition_video_indicator[:, :, -num_condition_t:] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: + """ + Adds pose condition to the condition object for camera control. + + Args: + data_batch (Dict): Data batch with 'plucker_embeddings' or 'plucker_embeddings_downsample'. + condition (VideoExtendCondition): Condition object to update. + + Returns: + VideoExtendCondition: Updated condition object. + """ + assert ( + "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() + ), f"plucker_embeddings should be in data_batch. 
only find {data_batch.keys()}" + plucker_embeddings = ( + data_batch["plucker_embeddings"] + if "plucker_embeddings_downsample" not in data_batch.keys() + else data_batch["plucker_embeddings_downsample"] + ) + condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() + to_cp = self.net.is_context_parallel_enabled + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition diff --git a/cosmos_predict1/diffusion/module/__init__.py b/cosmos_predict1/diffusion/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/module/attention.py b/cosmos_predict1/diffusion/module/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0a4b02dea9adbb9a320ed2631014d4102ed288 --- /dev/null +++ b/cosmos_predict1/diffusion/module/attention.py @@ -0,0 +1,313 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import numpy as np +import torch +import transformer_engine as te +from einops import rearrange +from torch import nn +from torch.utils.checkpoint import checkpoint +from transformer_engine.pytorch.attention import DotProductAttention, apply_rotary_pos_emb + +# ---------------------- Feed Forward Network ----------------------- + + +class FeedForward(nn.Module): + """ + Transformer FFN with optional gating + + Parameters: + d_model (int): Dimensionality of input features. + d_ff (int): Dimensionality of the hidden layer. + dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1. + activation (callable, optional): The activation function applied after the first linear layer. + Defaults to nn.ReLU(). + is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer. + Defaults to False. + bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True. 
+ + Example: + >>> ff = FeedForward(d_model=512, d_ff=2048) + >>> x = torch.randn(64, 10, 512) # Example input tensor + >>> output = ff(x) + >>> print(output.shape) # Expected shape: (64, 10, 512) + """ + + def __init__( + self, + d_model: int, + d_ff: int, + dropout: float = 0.1, + activation=nn.ReLU(), + is_gated: bool = False, + bias: bool = False, + ) -> None: + super().__init__() + + self.layer1 = nn.Linear(d_model, d_ff, bias=bias) + self.layer2 = nn.Linear(d_ff, d_model, bias=bias) + + self.dropout = nn.Dropout(dropout) + self.activation = activation + self.is_gated = is_gated + if is_gated: + self.linear_gate = nn.Linear(d_model, d_ff, bias=False) + + def forward(self, x: torch.Tensor): + g = self.activation(self.layer1(x)) + if self.is_gated: + x = g * self.linear_gate(x) + else: + x = g + assert self.dropout.p == 0.0, "we skip dropout" + return self.layer2(x) + + +class GPT2FeedForward(FeedForward): + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False): + super().__init__( + d_model=d_model, + d_ff=d_ff, + dropout=dropout, + activation=nn.GELU(), + is_gated=False, + bias=bias, + ) + + def forward(self, x: torch.Tensor): + assert self.dropout.p == 0.0, "we skip dropout" + + x = self.layer1(x) + + def activation_layer2_forward(x): + x = self.activation(x) + x = self.layer2(x) + return x + + x = checkpoint(activation_layer2_forward, x, use_reentrant=False) + return x + + +# ---------------------- Normalization Layer ----------------------- + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: + """ + Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. + + Args: + x (torch.Tensor): The input tensor to normalize. + dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. + eps (float, optional): A small constant to ensure numerical stability during division. + + Returns: + torch.Tensor: The normalized tensor. + """ + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +def get_normalization(name: str, channels: int): + if name == "I": + return nn.Identity() + elif name == "R": + return te.pytorch.RMSNorm(channels, eps=1e-6) + else: + raise ValueError(f"Normalization {name} not found") + + +class BaseAttentionOp(nn.Module): + def __init__(self): + super().__init__() + + +class Attention(nn.Module): + """ + Generalized attention impl. + + Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided. + If `context_dim` is None, self-attention is assumed. + + Parameters: + query_dim (int): Dimension of each query vector. + context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed. + heads (int, optional): Number of attention heads. Defaults to 8. + dim_head (int, optional): Dimension of each head. Defaults to 64. + dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0. + attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default. + qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False. + out_bias (bool, optional): If True, adds a learnable bias to the output projection. 
Defaults to False. + qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections. + Defaults to "SSI". + qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections. + Defaults to 'per_head'. Only support 'per_head'. + + Examples: + >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1) + >>> query = torch.randn(10, 128) # Batch size of 10 + >>> context = torch.randn(10, 256) # Batch size of 10 + >>> output = attn(query, context) # Perform the attention operation + + Note: + https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/cosmos_predict1/attention.py#L223 + """ + + def __init__( + self, + query_dim: int, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_op: Optional[BaseAttentionOp] = None, + qkv_bias: bool = False, + out_bias: bool = False, + qkv_norm: str = "SSI", + qkv_norm_mode: str = "per_head", + backend: str = "transformer_engine", + qkv_format: str = "bshd", + ) -> None: + super().__init__() + + self.is_selfattn = context_dim is None # self attention + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + self.qkv_norm_mode = qkv_norm_mode + self.qkv_format = qkv_format + + if self.qkv_norm_mode == "per_head": + norm_dim = dim_head + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + self.backend = backend + self.tp_size = 1 # TP is not included in this Attention implementation. + + self.to_q = nn.Sequential( + nn.Linear(query_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[0], norm_dim), + ) + self.to_k = nn.Sequential( + nn.Linear(context_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[1], norm_dim), + ) + self.to_v = nn.Sequential( + nn.Linear(context_dim, inner_dim, bias=qkv_bias), + get_normalization(qkv_norm[2], norm_dim), + ) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim, bias=out_bias), + nn.Dropout(dropout), + ) + + if attn_op: # use what is given + self.attn_op = attn_op + elif self.backend == "transformer_engine": + self.attn_op: BaseAttentionOp = DotProductAttention( + self.heads, + self.dim_head, + num_gqa_groups=self.heads, + attention_dropout=0, + qkv_format=qkv_format, + attn_mask_type="no_mask", + tp_size=self.tp_size, + tp_group=None, + sequence_parallel=False, + ) + elif self.backend == "torch": + self.attn_op = torch.nn.functional.scaled_dot_product_attention + else: + raise ValueError(f"Backend {backend} not found") + self.query_dim = query_dim + self.context_dim = context_dim + self.inner_dim = inner_dim + + def cal_qkv( + self, x, context=None, mask=None, rope_emb=None, **kwargs + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + del kwargs + + """ + self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers. + Before 07/24/2024, these modules normalize across all heads. + After 07/24/2024, to support tensor parallelism and follow the common practice in the community, + we support to normalize per head. + To keep the checkpoint copatibility with the previous code, + we keep the nn.Sequential but call the projection and the normalization layers separately. + We use a flag `self.qkv_norm_mode` to control the normalization behavior. + The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head. 
+ """ + if self.qkv_norm_mode == "per_head": + q = self.to_q[0](x) + context = x if context is None else context + k = self.to_k[0](context) + v = self.to_v[0](context) + q, k, v = map( + lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head), + (q, k, v), + ) + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + q = self.to_q[1](q) + k = self.to_k[1](k) + v = self.to_v[1](v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True) + return q, k, v + + def cal_attn(self, q, k, v, mask=None): + if self.backend == "transformer_engine": + seq_dim = self.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V] + return self.to_out(out) + elif self.backend == "torch": + q = rearrange(q, "s b h d -> b h s d") + k = rearrange(k, "s b h d -> b h s d") + v = rearrange(v, "s b h d -> b h s d") + out = self.attn_op(q, k, v) # [B, Mq, H, V] + return self.to_out(rearrange(out, " b h s d -> s b (h d)")) + else: + raise ValueError(f"Backend {self.backend} not found") + + def forward( + self, + x, + context=None, + mask=None, + rope_emb=None, + **kwargs, + ): + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) + return self.cal_attn(q, k, v, mask) diff --git a/cosmos_predict1/diffusion/module/blocks.py b/cosmos_predict1/diffusion/module/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..46b55225b078ec04983af2b3ab61712c6de9258f --- /dev/null +++ b/cosmos_predict1/diffusion/module/blocks.py @@ -0,0 +1,558 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Optional + +import numpy as np +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from torch import nn + +from cosmos_predict1.diffusion.module.attention import Attention, GPT2FeedForward +from cosmos_predict1.utils import log + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class Timesteps(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + in_dype = timesteps.dtype + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb.to(in_dype) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + log.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + else: + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_3D = emb + emb_B_D = sample + else: + emb_B_D = emb + adaln_lora_B_3D = None + + return emb_B_D, adaln_lora_B_3D + + +class FourierFeatures(nn.Module): + """ + Implements a layer that generates Fourier features from input tensors, based on randomly sampled + frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems. + + [B] -> [B, D] + + Parameters: + num_channels (int): The number of Fourier features to generate. + bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1. + normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize + the variance of the features. Defaults to False. + + Example: + >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True) + >>> x = torch.randn(10, 256) # Example input tensor + >>> output = layer(x) + >>> print(output.shape) # Expected shape: (10, 256) + """ + + def __init__(self, num_channels, bandwidth=1, normalize=False): + super().__init__() + self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True) + self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True) + self.gain = np.sqrt(2) if normalize else 1 + + def forward(self, x, gain: float = 1.0): + """ + Apply the Fourier feature transformation to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1. + + Returns: + torch.Tensor: The transformed tensor, with Fourier features applied. 
+ """ + in_dtype = x.dtype + x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32)) + x = x.cos().mul(self.gain * gain).to(in_dtype) + return x + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size, + temporal_patch_size, + in_channels=3, + out_channels=768, + bias=True, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + nn.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias + ), + ) + self.out = nn.Identity() + + def forward(self, x): + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return self.out(x) + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. 
+ """ + + def __init__( + self, + hidden_size, + spatial_patch_size, + temporal_patch_size, + out_channels, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) + ) + + def forward( + self, + x_BT_HW_D, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + +class VideoAttn(nn.Module): + """ + Implements video attention with optional cross-attention capabilities. + + This module processes video features while maintaining their spatio-temporal structure. It can perform + self-attention within the video features or cross-attention with external context features. + + Parameters: + x_dim (int): Dimension of input feature vectors + context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention + num_heads (int): Number of attention heads + bias (bool): Whether to include bias in attention projections. Default: False + qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head" + x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD" + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + + Input shape: + - x: (T, H, W, B, D) video features + - context (optional): (M, B, D) context features for cross-attention + where: + T: temporal dimension + H: height + W: width + B: batch size + D: feature dimension + M: context sequence length + """ + + def __init__( + self, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + bias: bool = False, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + n_views: int = 1, + ) -> None: + super().__init__() + self.x_format = x_format + self.n_views = n_views + self.attn = Attention( + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_bias=bias, + qkv_norm="RRI", + out_bias=bias, + qkv_norm_mode=qkv_norm_mode, + qkv_format="sbhd", + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for video attention. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. 
+ context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), + where M is the sequence length of the context. + crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor with applied attention, maintaining the input shape. + """ + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x, "(v t) h w b d -> t h w (v b) d", v=self.n_views) + context_M_B_D = rearrange(context, "(v m) b d -> m (v b) d", v=self.n_views) + else: + x_T_H_W_B_D = x + context_M_B_D = context + T, H, W, B, D = x_T_H_W_B_D.shape + x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") + x_THW_B_D = self.attn( + x_THW_B_D, + context_M_B_D, + crossattn_mask, + rope_emb=rope_emb_L_1_1_D, + ) + x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x_T_H_W_B_D, "t h w (v b) d -> (v t) h w b d", v=self.n_views) + return x_T_H_W_B_D + + +def adaln_norm_state(norm_state, x, scale, shift): + normalized = norm_state(x) + return normalized * (1 + scale) + shift + + +class DITBuildingBlock(nn.Module): + """ + A building block for the DiT (Diffusion Transformer) architecture that supports different types of + attention and MLP operations with adaptive layer normalization. + + Parameters: + block_type (str): Type of block - one of: + - "cross_attn"/"ca": Cross-attention + - "full_attn"/"fa": Full self-attention + - "mlp"/"ff": MLP/feedforward block + x_dim (int): Dimension of input features + context_dim (Optional[int]): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + bias (bool): Whether to use bias in layers. Default: False + mlp_dropout (float): Dropout rate for MLP. Default: 0.0 + qkv_norm_mode (str): QKV normalization mode. Default: "per_head" + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. 
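The adaptive-normalization pattern these blocks rely on (`modulate` / `adaln_norm_state` above) can be sketched standalone; the shapes below are assumptions chosen only for illustration.

```python
# Minimal standalone sketch of the AdaLN modulation used by these blocks
# (mirrors modulate()/adaln_norm_state()).
import torch
from torch import nn

D = 8
norm = nn.LayerNorm(D, elementwise_affine=False, eps=1e-6)
x = torch.randn(4, 16, D)     # (B, tokens, D)
emb = torch.randn(4, 2 * D)   # conditioning vector -> per-sample shift/scale
shift, scale = emb.chunk(2, dim=1)
out = norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # == modulate(norm(x), shift, scale)
print(out.shape)              # (4, 16, 8)
```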
+ """ + + def __init__( + self, + block_type: str, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + mlp_ratio: float = 4.0, + bias: bool = False, + mlp_dropout: float = 0.0, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ) -> None: + block_type = block_type.lower() + + super().__init__() + self.x_format = x_format + if block_type in ["cross_attn", "ca"]: + self.block = VideoAttn( + x_dim, + context_dim, + num_heads, + bias=bias, + qkv_norm_mode=qkv_norm_mode, + x_format=self.x_format, + n_views=n_views, + ) + elif block_type in ["full_attn", "fa"]: + self.block = VideoAttn( + x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format + ) + elif block_type in ["mlp", "ff"]: + self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.block_type = block_type + self.use_adaln_lora = use_adaln_lora + + self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.n_adaln_chunks = 3 + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for dynamically configured blocks with adaptive normalization. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). + emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. + crossattn_emb (Tensor): Tensor for cross-attention blocks. + crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor after processing through the configured block and adaptive normalization. 
+ """ + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( + shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + ) + + if self.block_type in ["mlp", "ff"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=crossattn_emb, + crossattn_mask=crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + + return x + + +class GeneralDITTransformerBlock(nn.Module): + """ + A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer. + Each block in the sequence is specified by a block configuration string. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention blocks + num_heads (int): Number of attention heads + block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention, + full-attention, then MLP) + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 + n_views (int): Extra parameter used in multi-view diffusion model. Default: 1 + + The block_config string uses "-" to separate block types: + - "ca"/"cross_attn": Cross-attention block + - "fa"/"full_attn": Full self-attention block + - "mlp"/"ff": MLP/feedforward block + + Example: + block_config = "ca-fa-mlp" creates a sequence of: + 1. Cross-attention block + 2. Full self-attention block + 3. 
MLP block + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + block_config: str, + mlp_ratio: float = 4.0, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ): + super().__init__() + self.blocks = nn.ModuleList() + self.x_format = x_format + for block_type in block_config.split("-"): + self.blocks.append( + DITBuildingBlock( + block_type, + x_dim, + context_dim, + num_heads, + mlp_ratio, + x_format=self.x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + n_views=n_views, + ) + ) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x = x + extra_per_block_pos_emb + for block in self.blocks: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + return x diff --git a/cosmos_predict1/diffusion/module/parallel.py b/cosmos_predict1/diffusion/module/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0837a2c4fec1506f70c8ba2b7c9f602b84c862 --- /dev/null +++ b/cosmos_predict1/diffusion/module/parallel.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state +from torch import Tensor +from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size +from torch.distributed.utils import _verify_param_shape_across_processes + +from cosmos_predict1.utils import distributed + + +def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Split input tensor along the sequence dimension for checkpoint parallelism. + + This function divides the input tensor into equal parts along the specified + sequence dimension, based on the number of ranks in the checkpoint parallelism group. + It then selects the part corresponding to the current rank. + + Args: + x: Input tensor to be split. + seq_dim: The dimension along which to split the input (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A slice of the input tensor corresponding to the current rank. + + Raises: + AssertionError: If the sequence dimension is not divisible by the number of ranks. 
+ """ + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + + assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" + x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) + seq_idx = torch.tensor([cp_group.rank()], device=x.device) + x = x.index_select(seq_dim, seq_idx) + # Note that the new sequence length is the original sequence length / cp_size + x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + return x + + +def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Concatenate outputs from different ranks in the checkpoint parallelism group. + + This function gathers tensors from all ranks in the checkpoint parallelism group + and concatenates them along the specified sequence dimension. + + Args: + x: Input tensor to be concatenated. + seq_dim: The dimension along which to concatenate the tensors (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A tensor that is the concatenation of tensors from all ranks in the cp_group. + + Raises: + RuntimeError: If the gather operation fails. + """ + # Get the world size (number of processes in the group) + world_size = get_world_size(cp_group) + + # Create a list to store tensors from all ranks + gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] + + # Gather tensors from all ranks + try: + all_gather(gathered_tensors, x, group=cp_group) + except RuntimeError as e: + raise RuntimeError(f"Failed to gather tensors: {e}") + + # Concatenate the gathered tensors along the specified dimension + return torch.cat(gathered_tensors, dim=seq_dim) + + +def broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: + """ + Broadcast the item from the minimum rank in the specified group(s). + Since global rank = tp_rank + cp_rank * tp_size + ... + First broadcast in the tp_group and then in the cp_group will + ensure that the item is broadcasted across ranks in cp_group and tp_group. + + Parameters: + - item: The item to broadcast (can be a torch.Tensor, str, or None). + - to_tp: Whether to broadcast to the tensor model parallel group. + - to_cp: Whether to broadcast to the context parallel group. 
+ """ + if not parallel_state.is_initialized(): + return item + tp_group = parallel_state.get_tensor_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + + to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 + to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 + + if to_tp: + min_tp_rank = min(get_process_group_ranks(tp_group)) + + if to_cp: + min_cp_rank = min(get_process_group_ranks(cp_group)) + + if isinstance(item, torch.Tensor): # assume the device is cuda + # log.info(f"{item.shape}", rank0_only=False) + if to_tp: + # torch.distributed.broadcast(item, min_tp_rank, group=tp_group) + item = _robust_broadcast(item, min_tp_rank, tp_group) + if to_cp: + # torch.distributed.broadcast(item, min_cp_rank, group=cp_group) + item = _robust_broadcast(item, min_cp_rank, cp_group) + elif item is not None: + broadcastable_list = [item] + if to_tp: + # log.info(f"{broadcastable_list}", rank0_only=False) + broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) + if to_cp: + broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) + + item = broadcastable_list[0] + return item + + +def _robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: + """ + Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. + + Args: + tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). + src (int): The source rank for the broadcast. Defaults to 0. + + Returns: + torch.Tensor: The broadcasted tensor on all ranks. + """ + # First, broadcast the shape of the tensor + if distributed.get_rank() == src: + shape = torch.tensor(tensor.shape).cuda() + else: + shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() + if is_check_shape: + _verify_param_shape_across_processes(pg, [shape]) + torch.distributed.broadcast(shape, src, group=pg) + + # Resize the tensor on non-src ranks if necessary + if distributed.get_rank() != src: + tensor = tensor.new_empty(shape.tolist()).type_as(tensor) + + # Now broadcast the tensor data + torch.distributed.broadcast(tensor, src, group=pg) + + return tensor diff --git a/cosmos_predict1/diffusion/module/position_embedding.py b/cosmos_predict1/diffusion/module/position_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..4f93ededcb9619eba86cc61139778e18c54a8f86 --- /dev/null +++ b/cosmos_predict1/diffusion/module/position_embedding.py @@ -0,0 +1,497 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +import numpy as np +import torch +from einops import rearrange, repeat +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.diffusion.module.attention import normalize +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.module.timm import trunc_normal_ + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class VideoPositionEmb(nn.Module): + def __init__(self): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, VideoRopePosition3DEmb): + seq_dim = 0 + else: + seq_dim = 1 + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input 
size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." + half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
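For reference, a tiny sketch of the `get_1d_sincos_pos_embed_from_grid` helper defined at the top of this file; the dimensions are arbitrary, and the import assumes the module's dependencies (megatron, transformer_engine) are available.

```python
# Illustrative call to the sin/cos positional-embedding helper above.
import numpy as np
from cosmos_predict1.diffusion.module.position_embedding import get_1d_sincos_pos_embed_from_grid

emb = get_1d_sincos_pos_embed_from_grid(64, np.arange(16))
print(emb.shape)  # (16, 64): first half of each row is sin(pos * omega), second half is cos
```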
+ """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels)) + self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels)) + self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels)) + + trunc_normal_(self.pos_emb_h, std=0.02) + trunc_normal_(self.pos_emb_w, std=0.02) + trunc_normal_(self.pos_emb_t, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) + + +class MultiviewVideoPositionEmb(nn.Module): + def __init__( + self, + ): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + With CP, the function assume that the input tensor is already split. It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + seq_dim = 1 + embeddings = rearrange(embeddings, "(V T) H W D -> V (T H W) 1 1 D", V=self.n_views).float() + # rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "V T 1 1 D -> (V T) 1 1 D", V=self.n_views).float() + else: + seq_dim = 1 + embeddings = rearrange(embeddings, "B (V T) H W C -> (B V) T H W C", V=self.n_views) + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "(B V) T H W C -> B (V T) H W C", V=self.n_views) + else: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + embeddings = rearrange(embeddings, "t h w d -> (t h w) 1 1 d").float() + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class MultiviewVideoRopePosition3DEmb(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + 
self.max_h = len_h + self.max_w = len_w + self.n_views = n_views + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embedding_for_batch( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert uniform_fps # only support uniform fps now + + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration." + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." 
+ half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return em_T_H_W_D + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. The camera view dimension is merged in the T dimension + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time * Views, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. + """ + + B, T, H, W, C = B_T_H_W_C + + single_view_B_T_H_W_C = (B, T // self.n_views, H, W, C) + em_T_H_W_D = torch.cat( + [ + self.generate_embedding_for_batch( + single_view_B_T_H_W_C, + fps=fps, + h_ntk_factor=h_ntk_factor, + w_ntk_factor=w_ntk_factor, + t_ntk_factor=t_ntk_factor, + ) + for item in range(self.n_views) + ], + dim=0, + ) + return em_T_H_W_D + + +class MultiviewSinCosPosEmbAxis(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
+ """ + del kwargs # unused + self.n_views = n_views + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + dim = model_channels + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + # rescale pos id is equivalent to rescale frequency + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + single_view_T = T // self.n_views + + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:single_view_T] + emb = torch.cat( + [ + torch.cat( + [ + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), + repeat(emb_h_H, "h d-> b t h w d", b=B, t=single_view_T, w=W), + repeat(emb_w_W, "w d-> b t h w d", b=B, t=single_view_T, h=H), + ], + dim=-1, + ) + for _ in range(self.n_views) + ], + 1, + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + return emb + + raise ValueError(f"Unknown interpolation method {self.interpolation}") diff --git a/cosmos_predict1/diffusion/module/pretrained_vae.py b/cosmos_predict1/diffusion/module/pretrained_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c799bfb1341b060afade60958a8162804001f50f --- /dev/null +++ b/cosmos_predict1/diffusion/module/pretrained_vae.py @@ -0,0 +1,611 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod + +import torch +from einops import rearrange +from torch.nn.modules import Module + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. 
+ """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class BasePretrainedImageVAE(BaseVAE): + """ + A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values + from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components from a remote store + + Attributes: + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
+ """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ) -> None: + super().__init__(latent_ch, name) + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.is_image = is_image + self.name = name + + def register_mean_std(self, vae_dir: str) -> None: + latent_mean, latent_std = torch.load(os.path.join(vae_dir, "image_mean_std.pt"), weights_only=True) + + target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the input state to latent space; also handle the dtype conversion, mean and std scaling + """ + in_dtype = state.dtype + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decode the input latent to state; also handle the dtype conversion, mean and std scaling + """ + in_dtype = latent.dtype + latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) + return self.decoder(latent.to(self.dtype)).to(in_dtype) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class JITVAE(BasePretrainedImageVAE): + """ + A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + name (str): Name of the model, used for differentiating cache file paths. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, latent_ch, is_image, is_bf16) + + def load_encoder(self, vae_dir: str) -> None: + """ + Load the encoder from the remote store. 
+ """ + self.encoder = torch.jit.load(os.path.join(vae_dir, "encoder.jit")) + + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, vae_dir: str) -> None: + """ + Load the decoder from the remote store. + """ + self.decoder = torch.jit.load(os.path.join(vae_dir, "decoder.jit")) + + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. + """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class VideoTokenizerInterface(ABC): + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + pass + + @abstractmethod + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + pass + + @property + @abstractmethod + def spatial_compression_factor(self): + pass + + @property + @abstractmethod + def temporal_compression_factor(self): + pass + + @property + @abstractmethod + def spatial_resolution(self): + pass + + @property + @abstractmethod + def pixel_chunk_duration(self): + pass + + @property + @abstractmethod + def latent_chunk_duration(self): + pass + + +class BasePretrainedVideoTokenizer(ABC): + """ + Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. + + Args: + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. + max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. + max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. + + The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) + which define how video data is subdivided and compressed during the encoding and decoding processes. 
The + `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory + constraints. + """ + + def __init__( + self, + pixel_chunk_duration: int = 17, + temporal_compress_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + ): + self._pixel_chunk_duration = pixel_chunk_duration + self._temporal_compress_factor = temporal_compress_factor + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + + def register_mean_std(self, vae_dir: str) -> None: + latent_mean, latent_std = torch.load(os.path.join(vae_dir, "mean_std.pt"), weights_only=True) + + latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + + target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + """ + Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding + """ + B, C, T, H, W = state.shape + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) + + def transform_decode_state_shape(self, latent: torch.Tensor) -> torch.Tensor: + B, _, T, _, _ = latent.shape + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = state.shape + state = rearrange(state, "b c t h w -> (b t) c 1 h w") + B, C, T, H, W = state.shape + state = self.transform_encode_state_shape(state) + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + latent = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + latent.append(super().encode(state[i : i + self.max_enc_batch_size])) + latent = torch.cat(latent, dim=0) + else: + latent = super().encode(state) + + latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) + return latent + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, + it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. + + It can also decode single frame image data. + + Args: + latent (torch.Tensor): The latent space tensor containing encoded video data. + + Returns: + torch.Tensor: The decoded video tensor reconstructed from latent space. 
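+
+        Example (shapes are illustrative, assuming `pixel_chunk_duration=17` and
+        `temporal_compress_factor=8`, i.e. `latent_chunk_duration=3`):
+            latent of shape (B, C, 6, h, w)  -> decoded as two chunks of 3 latent frames each
+            state  of shape (B, C, 34, H, W) -> reconstructed as two chunks of 17 pixel frames each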
+ """ + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = latent.shape + latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") + B, _, T, _, _ = latent.shape + latent = self.transform_decode_state_shape(latent) + # use max_enc_batch_size to avoid OOM + if latent.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, latent.shape[0], self.max_dec_batch_size): + state.append(super().decode(latent[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = super().decode(latent) + assert state.shape[2] == self.pixel_chunk_duration + state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) + return state + + @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + # return self._latent_chunk_duration + assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( + f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " + f"{self.latent_chunk_duration}" + ) + return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + @property + def temporal_compression_factor(self): + return self._temporal_compress_factor + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + +class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file + """ + + def __init__( + self, + name: str, + latent_ch: int = 16, + is_bf16: bool = True, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__( + pixel_chunk_duration, + temporal_compression_factor, + max_enc_batch_size, + max_dec_batch_size, + ) + super(BasePretrainedVideoTokenizer, self).__init__( + name, + latent_ch, + False, + is_bf16, + ) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + +class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( + self, + image_vae: torch.nn.Module, + video_vae: torch.nn.Module, + name: str, + latent_ch: int = 16, + squeeze_for_image: bool = True, + ): + super().__init__(latent_ch, name) + self.image_vae = image_vae + self.video_vae = video_vae + self.squeeze_for_image = squeeze_for_image + + def encode_image(self, state: torch.Tensor) -> torch.Tensor: + if 
self.squeeze_for_image: + return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) + return self.image_vae.encode(state) + + def decode_image(self, latent: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) + return self.image_vae.decode(latent) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return self.encode_image(state) + + return self.video_vae.encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + return self.decode_image(latent) + return self.video_vae.decode(latent) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.video_vae.reset_dtype() + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + return self.video_vae.get_latent_num_frames(num_pixel_frames) + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + return self.video_vae.get_pixel_num_frames(num_latent_frames) + + @property + def spatial_compression_factor(self): + return self.video_vae.spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self.video_vae.temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self.video_vae.spatial_resolution + + @property + def pixel_chunk_duration(self) -> int: + return self.video_vae.pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self.video_vae.latent_chunk_duration + + +class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): + """ + First version of the ImageVideoVAE trained with Fitsum. + We have to use seperate mean and std for image and video due to non-causal nature of the model. + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + assert isinstance(image_vae, JITVAE) + assert isinstance( + video_vae, VideoJITTokenizer + ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" + # a hack to make the image_vae and video_vae share the same encoder and decoder + + def load_weights(self, vae_dir: str): + # Load for video_vae + self.video_vae.register_mean_std(vae_dir) + self.video_vae.load_decoder(vae_dir) + self.video_vae.load_encoder(vae_dir) + + # Load for image_vae + self.image_vae.register_mean_std(vae_dir) + self.image_vae.load_decoder(vae_dir) + self.image_vae.load_encoder(vae_dir) diff --git a/cosmos_predict1/diffusion/module/timm.py b/cosmos_predict1/diffusion/module/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e5b1fdd15cc11f0aad45aaecbd5c78c5f27ce1 --- /dev/null +++ b/cosmos_predict1/diffusion/module/timm.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import torch +import torch.nn as nn + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.activation = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/cosmos_predict1/diffusion/modules/denoiser_scaling.py b/cosmos_predict1/diffusion/modules/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fb3df0f38d52de317177c248e22707a899beb4 --- /dev/null +++ b/cosmos_predict1/diffusion/modules/denoiser_scaling.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise diff --git a/cosmos_predict1/diffusion/modules/res_sampler.py b/cosmos_predict1/diffusion/modules/res_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..59a4952f69eb1bb1a708ee561941577cca7e5d75 --- /dev/null +++ b/cosmos_predict1/diffusion/modules/res_sampler.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general framework for various sampling algorithm from a diffusion model. 
+Impl based on +* Refined Exponential Solver (RES) in https://arxiv.org/pdf/2308.02157 +* also clude other impl, DDIM, DEIS, DPM-Solver, EDM sampler. +Most of sampling algorihtm, Runge-Kutta, Multi-step, etc, can be impl in this framework by \ + adding new step function in get_runge_kutta_fn or get_multi_step_fn. +""" + +import math +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import attrs +import torch + +from cosmos_predict1.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported +from cosmos_predict1.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported +from cosmos_predict1.utils.config import make_freezable + +COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] + + +@make_freezable +@attrs.define(slots=False) +class SolverConfig: + is_multi: bool = False + rk: str = "2mid" + multistep: str = "2ab" + # following parameters control stochasticity, see EDM paper + # BY default, we use deterministic with no stochasticity + s_churn: float = 0.0 + s_t_max: float = float("inf") + s_t_min: float = 0.05 + s_noise: float = 1.0 + + +@make_freezable +@attrs.define(slots=False) +class SolverTimestampConfig: + nfe: int = 50 + t_min: float = 0.002 + t_max: float = 80.0 + order: float = 7.0 + is_forward: bool = False # whether generate forward or backward timestamps + + +@make_freezable +@attrs.define(slots=False) +class SamplerConfig: + solver: SolverConfig = attrs.field(factory=SolverConfig) + timestamps: SolverTimestampConfig = attrs.field(factory=SolverTimestampConfig) + sample_clean: bool = True # whether run one last step to generate clean image + + +def get_rev_ts( + t_min: float, t_max: float, num_steps: int, ts_order: Union[int, float], is_forward: bool = False +) -> torch.Tensor: + """ + Generate a sequence of reverse time steps. + + Args: + t_min (float): The minimum time value. + t_max (float): The maximum time value. + num_steps (int): The number of time steps to generate. + ts_order (Union[int, float]): The order of the time step progression. + is_forward (bool, optional): If True, returns the sequence in forward order. Defaults to False. + + Returns: + torch.Tensor: A tensor containing the generated time steps in reverse or forward order. + + Raises: + ValueError: If `t_min` is not less than `t_max`. + TypeError: If `ts_order` is not an integer or float. 
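+
+    Example (values are illustrative):
+        get_rev_ts(t_min=0.002, t_max=80.0, num_steps=4, ts_order=7.0)
+        # -> a float64 tensor of 5 monotonically decreasing time steps, from 80.0 down to 0.002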
+ """ + if t_min >= t_max: + raise ValueError("t_min must be less than t_max") + + if not isinstance(ts_order, (int, float)): + raise TypeError("ts_order must be an integer or float") + + step_indices = torch.arange(num_steps + 1, dtype=torch.float64) + time_steps = ( + t_max ** (1 / ts_order) + step_indices / num_steps * (t_min ** (1 / ts_order) - t_max ** (1 / ts_order)) + ) ** ts_order + + if is_forward: + return time_steps.flip(dims=(0,)) + + return time_steps + + +class Sampler(torch.nn.Module): + def __init__(self, cfg: Optional[SamplerConfig] = None): + super().__init__() + if cfg is None: + cfg = SamplerConfig() + self.cfg = cfg + + @torch.no_grad() + def forward( + self, + x0_fn: Callable, + x_sigma_max: torch.Tensor, + num_steps: int = 35, + sigma_min: float = 0.002, + sigma_max: float = 80, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, + solver_option: str = "2ab", + ) -> torch.Tensor: + in_dtype = x_sigma_max.dtype + + def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tensor: + return x0_fn(x_B_StateShape.to(in_dtype), t_B.to(in_dtype)).to(torch.float64) + + is_multistep = is_multi_step_fn_supported(solver_option) + is_rk = is_runge_kutta_fn_supported(solver_option) + assert is_multistep or is_rk, f"Only support multistep or Runge-Kutta method, got {solver_option}" + + solver_cfg = SolverConfig( + s_churn=S_churn, + s_t_max=S_max, + s_t_min=S_min, + s_noise=S_noise, + is_multi=is_multistep, + rk=solver_option, + multistep=solver_option, + ) + timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho) + sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True) + + return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype) + + @torch.no_grad() + def _forward_impl( + self, + denoiser_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + noisy_input_B_StateShape: torch.Tensor, + sampler_cfg: Optional[SamplerConfig] = None, + callback_fns: Optional[List[Callable]] = None, + ) -> torch.Tensor: + """ + Internal implementation of the forward pass. + + Args: + denoiser_fn: Function to denoise the input. + noisy_input_B_StateShape: Input tensor with noise. + sampler_cfg: Configuration for the sampler. + callback_fns: List of callback functions to be called during sampling. + + Returns: + torch.Tensor: Denoised output tensor. + """ + sampler_cfg = self.cfg if sampler_cfg is None else sampler_cfg + solver_order = 1 if sampler_cfg.solver.is_multi else int(sampler_cfg.solver.rk[0]) + num_timestamps = sampler_cfg.timestamps.nfe // solver_order + + sigmas_L = get_rev_ts( + sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order + ).to(noisy_input_B_StateShape.device) + + denoised_output = differential_equation_solver( + denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns + )(noisy_input_B_StateShape) + + if sampler_cfg.sample_clean: + # Override denoised_output with fully denoised version + ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype) + denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones) + + return denoised_output + + +def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any: + """ + Implements a for loop with a function. + + Args: + lower: Lower bound of the loop (inclusive). + upper: Upper bound of the loop (exclusive). 
+ body_fun: Function to be applied in each iteration. + init_val: Initial value for the loop. + + Returns: + The final result after all iterations. + """ + val = init_val + for i in range(lower, upper): + val = body_fun(i, val) + return val + + +def differential_equation_solver( + x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + sigmas_L: torch.Tensor, + solver_cfg: SolverConfig, + callback_fns: Optional[List[Callable]] = None, +) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Creates a differential equation solver function. + + Args: + x0_fn: Function to compute x0 prediction. + sigmas_L: Tensor of sigma values with shape [L,]. + solver_cfg: Configuration for the solver. + callback_fns: Optional list of callback functions. + + Returns: + A function that solves the differential equation. + """ + num_step = len(sigmas_L) - 1 + + if solver_cfg.is_multi: + update_step_fn = get_multi_step_fn(solver_cfg.multistep) + else: + update_step_fn = get_runge_kutta_fn(solver_cfg.rk) + + eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1) + + def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor: + """ + Samples from the differential equation. + + Args: + input_xT_B_StateShape: Input tensor with shape [B, StateShape]. + + Returns: + Output tensor with shape [B, StateShape]. + """ + ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float64) + + def step_fn( + i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]] + ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + input_x_B_StateShape, x0_preds = state + sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1] + + # algorithm 2: line 4-6 + if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max: + hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0 + input_x_B_StateShape = input_x_B_StateShape + ( + hat_sigma_cur_0**2 - sigma_cur_0**2 + ).sqrt() * solver_cfg.s_noise * torch.randn_like(input_x_B_StateShape) + sigma_cur_0 = hat_sigma_cur_0 + + if solver_cfg.is_multi: + x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds + ) + else: + output_x_B_StateShape, x0_preds = update_step_fn( + input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn + ) + + if callback_fns: + for callback_fn in callback_fns: + callback_fn(**locals()) + + return output_x_B_StateShape, x0_preds + + x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None]) + return x_at_eps + + return sample_fn diff --git a/cosmos_predict1/diffusion/networks/__init__.py b/cosmos_predict1/diffusion/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/diffusion/networks/general_dit.py b/cosmos_predict1/diffusion/networks/general_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..c5031e1c189557b7c4cdff7dcdebedf196771d29 --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit.py @@ -0,0 +1,569 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +""" + +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks +from torchvision import transforms + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.attention import get_normalization +from cosmos_predict1.diffusion.module.blocks import ( + FinalLayer, + GeneralDITTransformerBlock, + PatchEmbed, + TimestepEmbedding, + Timesteps, +) +from cosmos_predict1.diffusion.module.position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb +from cosmos_predict1.utils import log + + +class GeneralDIT(nn.Module): + """ + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + block_config (str): Configuration of the transformer block. See Notes for supported block types. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + use_cross_attn_mask (bool): Whether to use mask in cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + affline_emb_norm (bool): Whether to normalize affine embeddings. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
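+
+    Example (a minimal sketch; all argument values below are illustrative only):
+        net = GeneralDIT(
+            max_img_h=704, max_img_w=1280, max_frames=128,
+            in_channels=16, out_channels=16,
+            patch_spatial=2, patch_temporal=1,
+            pos_emb_cls="rope3d",
+        )
+        # note: build_pos_embed() currently only supports pos_emb_cls="rope3d"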
+ + Notes: + Supported block types in block_config: + * cross_attn, ca: Cross attention + * full_attn: Full attention on all flattened tokens + * mlp, ff: Feed forward block + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = True, + extra_per_block_abs_pos_emb_type: str = "learnable", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.use_cross_attn_mask = use_cross_attn_mask + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.affline_emb_norm = affline_emb_norm + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + + self.build_patch_embed() + self.build_pos_embed() + self.cp_group = None + self.block_x_format = block_x_format + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + ) + + self.blocks = nn.ModuleDict() + + for idx in range(num_blocks): + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + mlp_ratio=mlp_ratio, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + ) + + self.build_decode_head() + if self.affline_emb_norm: + log.debug("Building affine embedding normalization layer") + self.affline_norm = get_normalization("R", model_channels) + else: + self.affline_norm = nn.Identity() + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer 
layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding + nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) + if self.t_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) + if self.t_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + for transformer_block in self.blocks.values(): + for block in transformer_block.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + if block.adaLN_modulation[-1].bias is not None: + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def build_decode_head(self): + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + ) + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + def build_pos_embed(self): + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + assert self.extra_per_block_abs_pos_emb is True, "extra_per_block_abs_pos_emb must be True" + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "learnable", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. 
+ + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def decoder_head( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del crossattn_emb, crossattn_mask + B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape + x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") + x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). 
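+        # The view below regroups tokens to (B * T_latent, H_latent * W_latent, p1 * p2 * t * C),
+        # which is the layout expected by the rearrange that follows when un-patchifying back to
+        # (B, C, T, H, W).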
+ x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + return x_B_D_T_H_W + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = 
DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to + augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + for _, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + return x_B_D_T_H_W + + def enable_context_parallel(self, cp_group: ProcessGroup): + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + self.pos_embedder.enable_context_parallel(cp_group) + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.enable_context_parallel(cp_group) + # Loop through the model to set up context parallel. 
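+        # Only self/full-attention layers receive the CP group: MLP and cross-attention blocks
+        # are skipped below, and the group is wired up only for the transformer_engine
+        # attention backend.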
+ for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff", "cross_attn", "ca"]: + continue + elif layer.block.attn.backend == "transformer_engine": + layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + + log.debug(f"[CP] Enable context parallelism with size {cp_size}") + + def disable_context_parallel(self): + self.cp_group = None + self.cp_size = None + + self.pos_embedder.disable_context_parallel() + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.disable_context_parallel() + + # Loop through the model to disable context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.cp_group = None + layer.block.attn.attn_op.cp_ranks = None + layer.block.attn.attn_op.cp_stream = None + + log.debug("[CP] Disable context parallelism.") + + @property + def is_context_parallel_enabled(self): + return self.cp_group is not None diff --git a/cosmos_predict1/diffusion/networks/general_dit_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..68ddbdb473365e6f4658a36d463f1e57479556e1 --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit_multiview.py @@ -0,0 +1,396 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.blocks import GeneralDITTransformerBlock, PatchEmbed +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.module.position_embedding import ( + MultiviewSinCosPosEmbAxis, + MultiviewVideoRopePosition3DEmb, +) +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class MultiviewGeneralDIT(GeneralDIT): + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = True, + extra_per_block_abs_pos_emb_type: str = "sincos", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + n_views: int = 3, + view_condition_dim: int = 3, + traj_condition_dim: int = 0, + concat_view_embedding: bool = True, + concat_traj_embedding: bool = False, + add_repeat_frame_embedding: bool = False, + ): + self.n_views = n_views + self.view_condition_dim = view_condition_dim + self.concat_view_embedding = concat_view_embedding + self.traj_condition_dim = traj_condition_dim + self.concat_traj_embedding = concat_traj_embedding + self.add_repeat_frame_embedding = add_repeat_frame_embedding + super().__init__( + max_img_h, + max_img_w, + max_frames, + in_channels, + out_channels, + patch_spatial, + patch_temporal, + concat_padding_mask, + block_config, + model_channels, + num_blocks, + num_heads, + mlp_ratio, + block_x_format, + crossattn_emb_channels, + use_cross_attn_mask, + pos_emb_cls, + pos_emb_learnable, + pos_emb_interpolation, + affline_emb_norm, # whether or not to normalize the affine embedding + use_adaln_lora, + adaln_lora_dim, + rope_h_extrapolation_ratio, + rope_w_extrapolation_ratio, + rope_t_extrapolation_ratio, + extra_per_block_abs_pos_emb, + extra_per_block_abs_pos_emb_type, + extra_h_extrapolation_ratio, + extra_w_extrapolation_ratio, + extra_t_extrapolation_ratio, + ) + # reinit self.blocks + del self.blocks + self.blocks = nn.ModuleDict() + for idx in range(self.num_blocks): + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + mlp_ratio=mlp_ratio, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + n_views=self.n_views, + ) + self.view_embeddings = nn.Embedding(n_views, view_condition_dim) # Learnable embedding layer + if 
self.concat_traj_embedding: + self.traj_embeddings = nn.Linear(192, self.traj_condition_dim) # Learnable embedding layer + if self.add_repeat_frame_embedding: + self.repeat_frame_embedding = nn.Linear(1, view_condition_dim) # Learnable embedding layer + + self.initialize_weights() + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + view_condition_dim, + traj_condition_dim, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + self.view_condition_dim, + self.traj_condition_dim, + ) + if self.concat_view_embedding: + in_channels = in_channels + view_condition_dim if view_condition_dim > 0 else in_channels + + if self.concat_traj_embedding: + in_channels = in_channels + traj_condition_dim if traj_condition_dim > 0 else in_channels + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + ) + + def build_pos_embed(self): + if self.pos_emb_cls == "rope3d": + cls_type = MultiviewVideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=30, + min_fps=1, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + n_views=self.n_views, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "sincos", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = MultiviewSinCosPosEmbAxis(**kwargs) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + trajectory = kwargs.get("trajectory", None) + frame_repeat = kwargs.get("frame_repeat", None) + + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
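+        # Same flow as GeneralDIT.forward_before_blocks, except that `trajectory` and
+        # `frame_repeat` are forwarded to prepare_embedded_sequence, which concatenates the
+        # per-view (and optionally trajectory) embeddings to the latent channels; frame_repeat,
+        # when given, is folded into the view embedding.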
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + trajectory=trajectory, + frame_repeat=frame_repeat, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + trajectory: Optional[torch.Tensor] = None, + frame_repeat: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. 
+ - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + + view_indices = torch.arange(self.n_views).to(x_B_C_T_H_W.device) # View indices [0, 1, ..., V-1] + view_embedding = self.view_embeddings(view_indices) # Shape: [V, embedding_dim] + view_embedding = rearrange(view_embedding, "V D -> D V") + view_embedding = view_embedding.unsqueeze(0).unsqueeze(3).unsqueeze(4).unsqueeze(5) # Shape: [1, D, V, 1, 1, 1] + + if self.add_repeat_frame_embedding: + if frame_repeat is None: + frame_repeat = ( + torch.zeros([x_B_C_T_H_W.shape[0], view_embedding.shape[1]]) + .to(view_embedding.device) + .to(view_embedding.dtype) + ) + frame_repeat_embedding = self.repeat_frame_embedding(frame_repeat.unsqueeze(-1)) + frame_repeat_embedding = rearrange(frame_repeat_embedding, "B V D -> B D V") + view_embedding = view_embedding + frame_repeat_embedding.unsqueeze(3).unsqueeze(4).unsqueeze(5) + + x_B_C_V_T_H_W = rearrange(x_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=self.n_views) + view_embedding = view_embedding.expand( + x_B_C_V_T_H_W.shape[0], + view_embedding.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + if self.concat_traj_embedding: + traj_emb = self.traj_embeddings(trajectory) + traj_emb = traj_emb.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) + traj_emb = traj_emb.expand( + x_B_C_V_T_H_W.shape[0], + traj_emb.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding, traj_emb], dim=1) + else: + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding], dim=1) + + x_B_C_T_H_W = rearrange(x_B_C_V_T_H_W, " B C V T H W -> B C (V T) H W", V=self.n_views) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb diff --git a/cosmos_predict1/diffusion/networks/general_dit_video_conditioned.py b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned.py new file mode 100644 index 0000000000000000000000000000000000000000..e19f20cc9cf95a1ba0e474d9d6fe1bffd413c18e --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from einops import rearrange +from torch import nn + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.blocks import TimestepEmbedding, Timesteps +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class VideoExtendGeneralDIT(GeneralDIT): + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.debug(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + def build_additional_timestamp_embedder(self): + super().build_additional_timestamp_embedder() + if self.add_augment_sigma_embedding: + log.info("Adding augment sigma embedding") + self.augment_sigma_embedder = nn.Sequential( + Timesteps(self.model_channels), + TimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + ) + + def initialize_weights(self): + if self.add_augment_sigma_embedding: + # Initialize timestep embedding for augment sigma + nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) + + super().initialize_weights() # Call this last since it wil call TP weight init + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Forward pass of the video-conditioned DIT model. 
+ + Args: + x: Input tensor of shape (B, C, T, H, W) + timesteps: Timestep tensor of shape (B,) + crossattn_emb: Cross attention embeddings of shape (B, N, D) + crossattn_mask: Optional cross attention mask of shape (B, N) + fps: Optional frames per second tensor + image_size: Optional image size tensor + padding_mask: Optional padding mask tensor + scalar_feature: Optional scalar features tensor + data_type: Type of data being processed (default: DataType.VIDEO) + video_cond_bool: Optional video conditioning boolean tensor + condition_video_indicator: Optional video condition indicator tensor + condition_video_input_mask: Required mask tensor for video data type + condition_video_augment_sigma: Optional sigma values for conditional input augmentation + **kwargs: Additional keyword arguments + + Returns: + torch.Tensor: Output tensor + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert condition_video_input_mask is not None, "condition_video_input_mask is required for video data type" + + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, _ = self.augment_sigma_embedder(condition_video_augment_sigma.flatten()) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/cosmos_predict1/diffusion/networks/general_dit_video_conditioned_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..702d72812b80d2327700203d021b80601a51b944 --- /dev/null +++ b/cosmos_predict1/diffusion/networks/general_dit_video_conditioned_multiview.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
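The channel bookkeeping in VideoExtendGeneralDIT.forward above is easy to get wrong, so here is a minimal shape sketch using plain torch tensors (batch size, frame count, and spatial sizes are illustrative; the 16 latent channels come from the tokenizer). It only reproduces the torch.cat along dim=1 that motivates the in_channels=16 + 1 default, plus one more channel when condition_video_pose is supplied.

import torch

B, C, T, H, W = 1, 16, 8, 40, 72                         # latent video (B, C, T, H, W); sizes are illustrative
x = torch.randn(B, C, T, H, W)
condition_video_input_mask = torch.zeros(B, 1, T, H, W)  # extra channel marking conditioned frames
condition_video_pose = torch.zeros(B, 1, T, H, W)        # optional pose channel

input_list = [x, condition_video_input_mask]
input_list.append(condition_video_pose)                  # only when pose conditioning is used
x_in = torch.cat(input_list, dim=1)
print(x_in.shape)  # torch.Size([1, 18, 8, 40, 72]); without pose it is 17 = 16 + 1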
+ +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.conditioner import DataType +from cosmos_predict1.diffusion.module.parallel import split_inputs_cp +from cosmos_predict1.diffusion.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils import log + + +class MultiviewVideoExtendGeneralDIT(MultiviewGeneralDIT): + def __init__(self, *args, in_channels, **kwargs): + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels + 1, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels + 1}") + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views + ) + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=3, cp_group=self.cp_group + ) + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views + ) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) diff --git a/cosmos_predict1/diffusion/prompt_upsampler/inference.py b/cosmos_predict1/diffusion/prompt_upsampler/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..494c9b882cafdbd7a7969966e8a51a97a759d1a1 --- /dev/null +++ b/cosmos_predict1/diffusion/prompt_upsampler/inference.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, TypedDict + +import torch + +from cosmos_predict1.autoregressive.model import AutoRegressiveModel +from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer +from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer + + +class ChatPrediction(TypedDict, total=False): + tokens: List[str] # not required + logprobs: List[float] # not required + + +def chat_completion( + model: AutoRegressiveModel, + dialogs: List, + seed: int = None, + temperature: float = 0.01, + top_k: int = None, + top_p: float = None, + max_gen_len: Optional[int] = None, + num_gen_seq: int = 1, + logprobs: bool = False, + generation_prefix: str = "", + compile_sampling: bool = False, + compile_prefill: bool = False, + stop_tokens=None, + verbose: bool = False, +) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + model (AutoRegressiveModel): The language generation model. + dialogs (List): List of conversational dialogs, where each dialog is a list of messages. + NOTE if you are using a VLM, all dialogs must either all have images ("image" field) or all be pure text. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.01. + top_k (int, optional): Top-k probability threshold for nucleus sampling. Defaults to None. If not None, top-p sampling is ignored. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None. If not None, top-k sampling is ignored. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + num_gen_seq (int, optional): Number of sequences to generate per prompt. Defaults to 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + generation_prefix (str, optional): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "". + compile_sampling (bool, optional): Flag indicating whether to compile the generation function. Defaults to False. + compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False. + stop_tokens (Set[int], optional): Set of tokens to stop generation. Defaults to None. If not None, it will override the model's stop tokens. + verbose (bool, optional): Flag indicating whether to print the generation throughput. Defaults to False. + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. 
+ If logprobs is True, token log probabilities are computed for each generated token. + """ + if max_gen_len is None: + max_gen_len = model.model.params.max_seq_len - 1 + images = None + if isinstance(model.tokenizer.text_tokenizer, ImageTextTokenizer): + # Vision-language model + prompt_dicts = [ + model.tokenizer.text_tokenizer.apply_chat_template( + dialog, generation_prefix=generation_prefix, add_generation_prompt=True + ) + for dialog in dialogs + ] + prompt_tokens = [prompt_dict["input_ids"] for prompt_dict in prompt_dicts] + num_images = sum(["pixel_values" in prompt_dict for prompt_dict in prompt_dicts]) + assert num_images in [0, len(dialogs)], "For VLM, all dialogs must either all have images or all be pure text." + if num_images > 0: + images = torch.cat([prompt_dict["pixel_values"] for prompt_dict in prompt_dicts], dim=0) + else: + images = None + elif isinstance(model.tokenizer.text_tokenizer, TextTokenizer): + # Text-only model + prompt_tokens = [ + model.tokenizer.text_tokenizer.apply_chat_template( + dialog, generation_prefix=generation_prefix, add_generation_prompt=True + ) + for dialog in dialogs + ] + else: + prompt_tokens = [model.formatter.encode_dialog_prompt(dialog) for dialog in dialogs] + + generation_tokens, generation_logprobs = model.generate( + prompt_tokens=prompt_tokens, + seed=seed, + max_gen_len=max_gen_len, + num_gen_seq=num_gen_seq, + temperature=temperature, + top_k=top_k, + top_p=top_p, + compile_sampling=compile_sampling, + compile_prefill=compile_prefill, + stop_tokens=stop_tokens, + verbose=verbose, + images=images, + ) + + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": model.tokenizer.text_tokenizer.decode(t), + }, + "tokens": [model.tokenizer.text_tokenizer.decode([x]) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [ + { + "generation": { + "role": "assistant", + "content": model.tokenizer.text_tokenizer.decode(t), + }, + } + for t in generation_tokens + ] diff --git a/cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py b/cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..16076d1afc06c5f217bfe9f65663cb829b534ca2 --- /dev/null +++ b/cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This demo script is used to run inference for Cosmos-UpsamplePrompt1-12B-Text2World. 
+Command: + CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py + +""" +import argparse +import os +import re + +from cosmos_predict1.autoregressive.configs.base.model_config import create_text_model_config +from cosmos_predict1.autoregressive.model import AutoRegressiveModel +from cosmos_predict1.auxiliary.guardrail.common import presets as guardrail_presets +from cosmos_predict1.diffusion.prompt_upsampler.inference import chat_completion +from cosmos_predict1.utils import log + + +def create_prompt_upsampler(checkpoint_dir: str) -> AutoRegressiveModel: + model_config, tokenizer_config = create_text_model_config( + model_ckpt_path=os.path.join(checkpoint_dir, "model.pt"), + tokenizer_path=os.path.join(checkpoint_dir), + model_family="mistral", + model_size="12b", + is_instruct_model=True, + max_batch_size=1, + rope_dim="1D", + add_special_tokens=True, + max_seq_len=1024, + pytorch_rope_version="v1", + ) + log.debug(f"Text prompt upsampler model config: {model_config}") + + # Create and return a LLM instance + return AutoRegressiveModel.build( + model_config=model_config, + tokenizer_config=tokenizer_config, + ).to("cuda") + + +def run_chat_completion(model: AutoRegressiveModel, input: str, temperature: float = 0.01): + """ + text2world prompt upsampler model is finetuned for chat. + During training, the context window for the initial prompt upsampler models is 512 tokens. For inference, we set max_seq_len to 1024 to accommodate longer inputs. + Setting `max_gen_len` is optional as the finetuned models can naturally determine when to stop generating. + """ + + dialogs = [[{"role": "user", "content": f"Upsample the short caption to a long caption: {str(input)}"}]] + + results = chat_completion( + model, + dialogs, + max_gen_len=512, + temperature=temperature, + top_p=None, + top_k=None, + logprobs=False, + ) + upsampled_prompt = str(clean_text(results[0]["generation"]["content"])) + return upsampled_prompt + + +def clean_text(text: str) -> str: + """Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace.""" + # Replace all variations of newlines with a space + text = text.replace("\n", " ").replace("\r", " ") + + # Use a regex to find sections of the form '- **...**' + pattern = r"(- \*\*)(.*?)(\*\*)" + + def replacement(match: re.Match[str]) -> str: + content = match.group(2) # The text inside - ** and ** + words = re.findall(r"\w+", content) + if len(words) < 10: + # If fewer than 10 words, remove the entire '- **...**' portion + return "" + else: + # If 10 or more words, keep the entire section as it is + return match.group(0) + + text = re.sub(pattern, replacement, text) + + # Remove common prefixes + prefixes = ["Caption:", "#####", "####", "- ", "* ", ","] + for prefix in prefixes: + # lstrip(prefix) won't strip entire strings, but character sets. 
+ # For more reliable prefix removal, do: + if text.startswith(prefix): + text = text[len(prefix) :].lstrip() + + # Remove extra spaces + text = " ".join(text.split()) + + # Strip any remaining leading/trailing punctuation, whitespace, and quotes + text = text.strip(' -,*:"\'"“”') + + return text + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run prompt upsampler inference") + parser.add_argument("--input", type=str, default="A dog is playing with a ball.") + parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Cosmos-UpsamplePrompt1-12B-Text2World", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + return parser.parse_args() + + +def main(args): + guardrail_runner = guardrail_presets.create_text_guardrail_runner(args.checkpoint_dir) + is_safe = guardrail_presets.run_text_guardrail(args.input, guardrail_runner) + if not is_safe: + log.critical("Input text prompt is not safe.") + return + + prompt_upsampler = create_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir)) + upsampled_prompt = run_chat_completion(prompt_upsampler, args.input, temperature=args.temperature) + is_safe = guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner) + if not is_safe: + log.critical("Upsampled text prompt is not safe.") + return + + log.info(f"Upsampled prompt: {upsampled_prompt}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py b/cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f40c6cbfbfc078fb891bd75dc0f04236c12ced --- /dev/null +++ b/cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This demo script is used to run inference for Pixtral-12B. 
+Command: + CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py + +""" + +import argparse +import os +from math import ceil + +from PIL import Image + +from cosmos_predict1.autoregressive.configs.base.model_config import create_vision_language_model_config +from cosmos_predict1.autoregressive.model import AutoRegressiveModel +from cosmos_predict1.auxiliary.guardrail.common import presets as guardrail_presets +from cosmos_predict1.diffusion.prompt_upsampler.inference import chat_completion +from cosmos_predict1.utils import log +from cosmos_predict1.utils.io import load_from_fileobj + + +def create_vlm_prompt_upsampler( + checkpoint_dir: str, tokenizer_ckpt_path: str = "mistral-community/pixtral-12b" +) -> AutoRegressiveModel: + """ + Load the fine-tuned pixtral model for SimReady. + If pixtral_ckpt is not provided, use the pretrained checkpoint. + """ + model_ckpt_path = os.path.join(checkpoint_dir, "model.pt") + model_config, tokenizer_config = create_vision_language_model_config( + model_ckpt_path=model_ckpt_path, + tokenizer_ckpt_path=tokenizer_ckpt_path, + model_family="pixtral", + model_size="12b", + is_instruct_model=True, + max_batch_size=1, + max_seq_len=4300, + pytorch_rope_version="v1", + ) + # during instantiate, the weights will be downloaded (if not already cached) and loaded + return AutoRegressiveModel.build( + model_config=model_config, + tokenizer_config=tokenizer_config, + ).to("cuda") + + +def resize_image(image: Image.Image, max_size: int = 1024) -> Image.Image: + """ + Ensure that the image is no larger than max_size in both dimensions. + """ + image_width, image_height = image.size + max_width, max_height = max_size, max_size + ratio = max(image_width / max_width, image_height / max_height) + if ratio > 1: + image = image.resize((ceil(image_width / ratio), ceil(image_height / ratio))) + return image + + +def prepare_dialog(image_or_video_path: str) -> list[dict]: + if image_or_video_path.endswith(".mp4"): + video_np, _ = load_from_fileobj(image_or_video_path, format="mp4") + image_frame = video_np[-1] + image = Image.fromarray(image_frame) + else: + image: Image.Image = Image.open(image_or_video_path) + + image = resize_image(image, max_size=1024) + prompt = """\ +Your task is to transform a given prompt into a refined and concise video description, no more than 150 words. +Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video. 
+ """.strip() + + return [ + { + "role": "user", + "content": "[IMG]\n" + prompt, + "images": [image], + } + ] + + +def run_chat_completion(pixtral: AutoRegressiveModel, dialog: list[dict], **inference_args) -> str: + default_args = { + "max_gen_len": 400, + "temperature": 0, + "top_p": 0.9, + "logprobs": False, + "compile_sampling": False, + "compile_prefill": False, + } + default_args.update(inference_args) + results = chat_completion( + pixtral, + [dialog], + **default_args, + ) + assert len(results) == 1 + upsampled_prompt = str(results[0]["generation"]["content"]) + return upsampled_prompt + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run prompt upsampler inference") + parser.add_argument("--image_or_video_path", type=str, default="assets/diffusion/video2world_input0.jpg") + parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature") + parser.add_argument("--top_p", type=float, default=0.9, help="Top-p value for top-p sampling") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" + ) + parser.add_argument( + "--prompt_upsampler_dir", + type=str, + default="Pixtral-12B", + help="Prompt upsampler weights directory relative to checkpoint_dir", + ) + return parser.parse_args() + + +def main(args): + guardrail_runner = guardrail_presets.create_text_guardrail_runner(args.checkpoint_dir) + + pixtral = create_vlm_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir)) + dialog = prepare_dialog(args.image_or_video_path) + upsampled_prompt = run_chat_completion( + pixtral, + dialog, + max_gen_len=400, + temperature=args.temperature, + top_p=args.top_p, + logprobs=False, + ) + is_safe = guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner) + if not is_safe: + log.critical("Upsampled text prompt is not safe.") + return + + log.info(f"Upsampled prompt: {upsampled_prompt}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/cosmos_predict1/diffusion/training/callbacks/every_n.py b/cosmos_predict1/diffusion/training/callbacks/every_n.py new file mode 100644 index 0000000000000000000000000000000000000000..25cab309a58336867ed5fc58849e71db7611d0f3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/every_n.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
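For orientation, a minimal driver for the video2world prompt upsampler defined above, mirroring its main() entry point; the checkpoint directory layout and the input image path are assumptions taken from the script's argparse defaults, and a GPU with the Pixtral-12B weights downloaded is required.

import os

from cosmos_predict1.diffusion.prompt_upsampler.video2world_prompt_upsampler_inference import (
    create_vlm_prompt_upsampler,
    prepare_dialog,
    run_chat_completion,
)

checkpoint_dir = "checkpoints"                                      # assumed checkpoint layout
pixtral = create_vlm_prompt_upsampler(os.path.join(checkpoint_dir, "Pixtral-12B"))
dialog = prepare_dialog("assets/diffusion/video2world_input0.jpg")  # image or .mp4; for video the last frame is used
upsampled_prompt = run_chat_completion(pixtral, dialog, temperature=0.01, top_p=0.9)
print(upsampled_prompt)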
+ +from abc import abstractmethod +from typing import Optional + +import torch + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.callback import Callback +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class EveryN(Callback): + def __init__( + self, + every_n: Optional[int] = None, + step_size: int = 1, + barrier_after_run: bool = True, + run_at_start: bool = False, + ) -> None: + """Constructor for `EveryN`. + + Args: + every_n (int): Frequency with which callback is run during training. + step_size (int): Size of iteration step count. Default 1. + barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts. + run_at_start (bool): Whether to run at the beginning of training. Default False. + """ + self.every_n = every_n + if self.every_n == 0: + log.warning( + f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped." + ) + + self.step_size = step_size + self.barrier_after_run = barrier_after_run + self.run_at_start = run_at_start + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training + if self.every_n != 0: + trainer = self.trainer + global_step = iteration // self.step_size + should_run = (iteration == 1 and self.run_at_start) or ( + global_step % self.every_n == 0 + ) # (self.every_n - 1) + if should_run: + log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}") + self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration) + log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}") + # add necessary barrier to avoid timeout + if self.barrier_after_run: + distributed.barrier() + + @abstractmethod + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + ... diff --git a/cosmos_predict1/diffusion/training/callbacks/grad_clip.py b/cosmos_predict1/diffusion/training/callbacks/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..8c060fc1a7514734cd3e429f5e3aba7c4c57f7e0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/grad_clip.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
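The EveryN base class above does the step accounting, so a concrete callback only implements every_n_impl. A minimal sketch of that contract (LossPrinter is a hypothetical callback, not one that ships with the repo):

from cosmos_predict1.diffusion.training.callbacks.every_n import EveryN


class LossPrinter(EveryN):
    def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
        # Invoked by on_training_step_end roughly every `every_n * step_size` iterations.
        print(f"iteration {iteration}: loss {loss.item():.4f}")


# barrier_after_run=False keeps the sketch free of distributed.barrier() calls.
callback = LossPrinter(every_n=100, step_size=1, barrier_after_run=False)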
+ +from dataclasses import dataclass +from typing import Tuple + +import torch +from megatron.core import parallel_state +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_predict1.utils import distributed +from cosmos_predict1.utils.callbacks.grad_clip import GradClip as GradClipImage +from cosmos_predict1.utils.callbacks.grad_clip import _fused_nan_to_num +from cosmos_predict1.utils.model import Model + + +@dataclass +class _MagnitudeRecord: + state: float = 0 + iter_count: int = 0 + + def reset(self) -> None: + self.state = 0 + self.iter_count = 0 + + def update(self, cur_state: torch.Tensor) -> None: + self.state += cur_state + self.iter_count += 1 + + def get_stat(self) -> Tuple[float, float]: + if self.iter_count > 0: + avg_state = self.state / self.iter_count + avg_state = avg_state.item() + else: + avg_state = 0 + self.reset() + return avg_state + + +class GradClip(GradClipImage): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.img_mag_log = _MagnitudeRecord() + self.video_mag_log = _MagnitudeRecord() + self._cur_state = None + + def on_training_step_start(self, model: Model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None: + if model.is_image_batch(data_batch): + self._cur_state = self.img_mag_log + else: + self._cur_state = self.video_mag_log + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + params = [] + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + if self.force_finite: + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + if isinstance(model, FSDP) and self.fsdp_enabled: + total_norm = model.clip_grad_norm_(self.clip_norm) + else: + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + total_norm = model_ddp.module.clip_grad_norm_(self.clip_norm) + else: + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) + + self._cur_state.update(total_norm) diff --git a/cosmos_predict1/diffusion/training/callbacks/iter_speed.py b/cosmos_predict1/diffusion/training/callbacks/iter_speed.py new file mode 100644 index 0000000000000000000000000000000000000000..371a227c6db390c0e3764ad6bb3e278bdd1ae866 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/iter_speed.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +from torch import Tensor + +from cosmos_predict1.diffusion.training.callbacks.every_n import EveryN +from cosmos_predict1.utils import log +from cosmos_predict1.utils.distributed import rank0_only +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class IterSpeed(EveryN): + """ + Args: + hit_thres (int): Number of iterations to wait before logging. + """ + + def __init__(self, *args, hit_thres: int = 5, **kwargs): + super().__init__(*args, **kwargs) + self.time = None + self.hit_counter = 0 + self.hit_thres = hit_thres + self.name = self.__class__.__name__ + self.last_hit_time = time.time() + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + if self.hit_counter < self.hit_thres: + log.info( + f"Iteration {iteration}: " + f"Hit counter: {self.hit_counter + 1}/{self.hit_thres} | " + f"Loss: {loss.item():.4f} | " + f"Time: {time.time() - self.last_hit_time:.2f}s" + ) + self.hit_counter += 1 + self.last_hit_time = time.time() + #! useful for large scale training and avoid oom crash in the first two iterations!!! + torch.cuda.synchronize() + return + super().on_training_step_end(model, data_batch, output_batch, loss, iteration) + + @rank0_only + def every_n_impl( + self, + trainer: Trainer, + model: Model, + data_batch: dict[str, Tensor], + output_batch: dict[str, Tensor], + loss: Tensor, + iteration: int, + ) -> None: + if self.time is None: + self.time = time.time() + return + cur_time = time.time() + iter_speed = (cur_time - self.time) / self.every_n / self.step_size + + log.info(f"{iteration} : iter_speed {iter_speed:.2f} seconds per iteration | Loss: {loss.item():.4f}") + + self.time = cur_time diff --git a/cosmos_predict1/diffusion/training/callbacks/low_precision.py b/cosmos_predict1/diffusion/training/callbacks/low_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..faf9562c413f33f40d01b69e4bb01883283f0895 --- /dev/null +++ b/cosmos_predict1/diffusion/training/callbacks/low_precision.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from cosmos_predict1.diffusion.training.trainer import Trainer +from cosmos_predict1.utils.callback import LowPrecisionCallback as BaseCallback +from cosmos_predict1.utils.config import Config +from cosmos_predict1.utils.model import Model + + +class LowPrecisionCallback(BaseCallback): + """ + Config with non-primitive type makes it difficult to override the option. + The callback gets precision from model.precision instead. 
+ """ + + def __init__(self, config: Config, trainer: Trainer, update_iter: int): + self.config = config + self.trainer = trainer + self.update_iter = update_iter + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + assert model.precision in [ + torch.bfloat16, + torch.float16, + torch.half, + ], "LowPrecisionCallback must use a low precision dtype." + self.precision_type = model.precision diff --git a/cosmos_predict1/diffusion/training/conditioner.py b/cosmos_predict1/diffusion/training/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..fee7bfb14ad9d0f9346aed41304401d6383c8d2f --- /dev/null +++ b/cosmos_predict1/diffusion/training/conditioner.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from cosmos_predict1.diffusion.conditioner import GeneralConditioner +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.utils.misc import count_params + + +class DataType(Enum): + IMAGE = "image" + VIDEO = "video" + MIX = "mix" + + +class AbstractEmbModel(nn.Module): + def __init__(self): + super().__init__() + + self._is_trainable = None + self._dropout_rate = None + self._input_key = None + self._return_dict = False + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def dropout_rate(self) -> Union[float, torch.Tensor]: + return self._dropout_rate + + @property + def input_key(self) -> str: + return self._input_key + + @property + def is_return_dict(self) -> bool: + return self._return_dict + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + @dropout_rate.setter + def dropout_rate(self, value: Union[float, torch.Tensor]): + self._dropout_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_return_dict.setter + def is_return_dict(self, value: bool): + self._return_dict = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @dropout_rate.deleter + def dropout_rate(self): + del self._dropout_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + @is_return_dict.deleter + def is_return_dict(self): + del self._return_dict + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + return batch_mul( + torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor), + in_tensor, + ) + + def details(self) -> str: + return "" + + 
def summary(self) -> str: + input_key = self.input_key if self.input_key is not None else getattr(self, "input_keys", None) + return ( + f"{self.__class__.__name__} \n\tinput key: {input_key}" + f"\n\tParam count: {count_params(self, False)} \n\tTrainable: {self.is_trainable}" + f"\n\tDropout rate: {self.dropout_rate}" + f"\n\t{self.details()}" + ) + + +class TrajectoryAttr(AbstractEmbModel): + def __init__(self, traj_dim: int): + super().__init__() + self.traj_dim = traj_dim + + def forward(self, traj: torch.Tensor) -> Dict[str, torch.Tensor]: + return { + "trajectory": traj, + } + + def details(self) -> str: + return f"Traj dim : {self.traj_dim} \n\tOutput key: [trajectory]" + + +class FrameRepeatAttr(AbstractEmbModel): + def __init__(self): + super().__init__() + + def forward(self, frame_repeat: torch.Tensor) -> Dict[str, torch.Tensor]: + return { + "frame_repeat": frame_repeat / 10.0, + } + + def details(self) -> str: + return "Frame repeat, Output key: [frame_repeat]" + + +@dataclass +class BaseVideoCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + data_type: DataType = DataType.VIDEO + padding_mask: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + image_size: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + trajectory: Optional[torch.Tensor] = None + frame_repeat: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +@dataclass +class VideoExtendCondition(BaseVideoCondition): + video_cond_bool: Optional[torch.Tensor] = None # whether or not it conditioned on video + gt_latent: Optional[torch.Tensor] = None + condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region + + # condition_video_input_mask will concat to the input of network, along channel dim; + # Will be concat with the input tensor + condition_video_input_mask: Optional[torch.Tensor] = None + # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" + condition_video_augment_sigma: Optional[torch.Tensor] = None + # pose conditional input, will be concat with the input tensor + condition_video_pose: Optional[torch.Tensor] = None + + +@dataclass +class VideoLatentDiffusionDecoderCondition(BaseVideoCondition): + # latent_condition will concat to the input of network, along channel dim; + # cfg will make latent_condition all zero padding. 
+ latent_condition: Optional[torch.Tensor] = None + latent_condition_sigma: Optional[torch.Tensor] = None + + def get_condition_for_cp(self, cp_group): + self.latent_condition = split_inputs_cp(x=self.latent_condition, seq_dim=2, cp_group=cp_group) + self.latent_condition_sigma = split_inputs_cp(x=self.latent_condition_sigma, seq_dim=2, cp_group=cp_group) + + +class VideoConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseVideoCondition: + output = super()._forward(batch, override_dropout_rate) + return BaseVideoCondition(**output) + + +class VideoDiffusionDecoderConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoLatentDiffusionDecoderCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoLatentDiffusionDecoderCondition(**output) + + +class VideoExtendConditioner(GeneralConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoExtendCondition(**output) + + +class VideoConditionerWithTraingOnlyEmb(GeneralConditioner): + def get_condition_uncondition( + self, + data_batch: Dict, + ) -> Tuple[Any, Any]: + """ + Processes the provided data batch to generate two sets of outputs: conditioned and unconditioned. This method + manipulates the dropout rates of embedders to simulate two scenarios — one where all conditions are applied + (conditioned), and one where they are removed or reduced to the minimum (unconditioned). + + This method first sets the dropout rates to zero for the conditioned scenario to fully apply the embedders' effects. + For the unconditioned scenario, it sets the dropout rates to 1 (or to 0 if the initial unconditional dropout rate + is insignificant) to minimize the embedders' influences, simulating an unconditioned generation. + + Parameters: + data_batch (Dict): The input data batch that contains all necessary information for embedding processing. The + data is expected to match the required format and keys expected by the embedders. + + Returns: + Tuple[Any, Any]: A tuple containing two condition: + - The first one contains the outputs with all embedders fully applied (conditioned outputs). + - The second one contains the outputs with embedders minimized or not applied (unconditioned outputs). 
+ """ + cond_dropout_rates, dropout_rates = {}, {} + for emb_name, embedder in self.embedders.items(): + if isinstance(embedder, FrameRepeatAttr): + cond_dropout_rates[emb_name] = 1.0 + else: + cond_dropout_rates[emb_name] = 0.0 + dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 + + condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) + un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates) + return condition, un_condition + + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseVideoCondition: + output = super()._forward(batch, override_dropout_rate) + return BaseVideoCondition(**output) + + +class VideoExtendConditionerWithTraingOnlyEmb(VideoConditionerWithTraingOnlyEmb): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> VideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + return VideoExtendCondition(**output) + + +@dataclass +class BaseWithCtrlCondition(VideoExtendCondition): + control_input_canny: Optional[torch.Tensor] = None + control_input_blur: Optional[torch.Tensor] = None + control_input_canny_blur: Optional[torch.Tensor] = None + control_input_depth: Optional[torch.Tensor] = None + control_input_segmentation: Optional[torch.Tensor] = None + control_input_depth_segmentation: Optional[torch.Tensor] = None + control_input_mask: Optional[torch.Tensor] = None + control_input_human_kpts: Optional[torch.Tensor] = None + control_input_upscale: Optional[torch.Tensor] = None + control_input_identity: Optional[torch.Tensor] = None + control_input_multi: Optional[torch.Tensor] = None + base_model: Optional[torch.nn.Module] = None + hint_key: Optional[str] = None + control_weight: Optional[float] = 1.0 + num_layers_to_use: Optional[int] = -1 + + +class VideoConditionerWithCtrl(VideoExtendConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> BaseWithCtrlCondition: + output = super()._forward(batch, override_dropout_rate) + output["hint_key"] = batch["hint_key"] + if "control_weight" in batch: + output["control_weight"] = batch["control_weight"] + if "num_layers_to_use" in batch: + output["num_layers_to_use"] = batch["num_layers_to_use"] + return BaseWithCtrlCondition(**output) + + +class BooleanFlag(AbstractEmbModel): + def __init__(self, output_key: Optional[str] = None): + super().__init__() + self.output_key = output_key + + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + del args, kwargs + key = self.output_key if self.output_key else self.input_key + return {key: self.flag} + + def random_dropout_input( + self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None + ) -> torch.Tensor: + del key + dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate + self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device) + return in_tensor + + def details(self) -> str: + key = self.output_key if self.output_key else self.input_key + return f"Output key: {key} \n\t This is a boolean flag" diff --git a/cosmos_predict1/diffusion/training/config/base/ema.py b/cosmos_predict1/diffusion/training/config/base/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..30eea4fe92403590d6812403e73b46ff8d4bded4 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/ema.py @@ -0,0 +1,27 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.utils.ema import EMAModelTracker, PowerEMATracker +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +PowerEMAConfig: LazyDict = L(PowerEMATracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.10, num=3 +) + +RegEMAConfig: LazyDict = L(EMAModelTracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.999, num=1 +) diff --git a/cosmos_predict1/diffusion/training/config/base/model.py b/cosmos_predict1/diffusion/training/config/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..47e0b76154805d61876acc486ccc5e13eb7c7fdb --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/model.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
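A hedged sketch of how these lazy EMA configs are usually consumed: an experiment config copies one and overrides fields before the trainer instantiates it (model=PLACEHOLDER is filled in at that point). The deepcopy-and-assign pattern assumes LazyDict behaves like an omegaconf DictConfig, which is how it is used elsewhere in this diff; the override values are illustrative.

import copy

from cosmos_predict1.diffusion.training.config.base.ema import PowerEMAConfig, RegEMAConfig

ema_cfg = copy.deepcopy(PowerEMAConfig)
ema_cfg.rate = 0.05         # slower power-EMA decay than the 0.10 default above
ema_cfg.num = 1             # keep a single EMA copy instead of three

no_ema_cfg = copy.deepcopy(RegEMAConfig)
no_ema_cfg.enabled = False  # disable EMA for a quick ablation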
+ +from typing import List + +import attrs + +from cosmos_predict1.diffusion.training.config.base.ema import PowerEMAConfig +from cosmos_predict1.diffusion.training.modules.edm_sde import EDMSDE +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class FSDPConfig: + policy: str = "block" + checkpoint: bool = False + min_num_params: int = 1024 + sharding_group_size: int = 8 + sharding_strategy: str = "full" + + +@attrs.define(slots=False) +class DefaultModelConfig: + vae: LazyDict = None + conditioner: LazyDict = None + net: LazyDict = None + ema: LazyDict = PowerEMAConfig + sde: LazyDict = L(EDMSDE)( + p_mean=0.0, + p_std=1.0, + sigma_max=80, + sigma_min=0.0002, + ) + sigma_data: float = 0.5 + camera_sample_weight: LazyDict = LazyDict( + dict( + enabled=False, + weight=5.0, + ) + ) + aesthetic_finetuning: LazyDict = LazyDict( + dict( + enabled=False, + ) + ) + loss_mask_enabled: bool = False + loss_masking: LazyDict = None + loss_add_logvar: bool = True + precision: str = "bfloat16" + input_data_key: str = "video" # key to fetch input data from data_batch + input_image_key: str = "images_1024" # key to fetch input image from data_batch + loss_reduce: str = "sum" + loss_scale: float = 1.0 + latent_shape: List[int] = [16, 24, 44, 80] # 24 corresponding to 136 frames + fsdp_enabled: bool = False + use_torch_compile: bool = False + fsdp: FSDPConfig = attrs.field(factory=FSDPConfig) + use_dummy_temporal_dim: bool = False # Whether to use dummy temporal dimension in data + adjust_video_noise: bool = False # whether or not to adjust video noise according to the video length + peft_control: LazyDict | None = None + + +@attrs.define(slots=False) +class MultiviewModelConfig(DefaultModelConfig): + n_views: int = 6 diff --git a/cosmos_predict1/diffusion/training/config/base/optim.py b/cosmos_predict1/diffusion/training/config/base/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6ddb236d2eb97576e521ca4308166e628047ed --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/optim.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from cosmos_predict1.diffusion.training.functional.lr_scheduler import LambdaLinearScheduler +from cosmos_predict1.diffusion.training.utils.optim_instantiate import get_base_optimizer +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +FusedAdamWConfig: LazyDict = L(get_base_optimizer)( + model=PLACEHOLDER, + lr=1e-4, + weight_decay=0.3, + betas=[0.9, 0.999], + optim_type="fusedadam", + eps=1e-8, + sharding=False, + master_weights=True, + capturable=True, +) + +LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)( + warm_up_steps=[1000], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], +) diff --git a/cosmos_predict1/diffusion/training/config/base/vae.py b/cosmos_predict1/diffusion/training/config/base/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..8d87a1f78e0522a54cb4dc7772235b36856d6ad9 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/base/vae.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import omegaconf + +from cosmos_predict1.diffusion.training.module.pretrained_vae import VideoJITTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L + +TOKENIZER_OPTIONS = {} + + +def tokenizer_register(key): + def decorator(func): + TOKENIZER_OPTIONS[key] = func + return func + + return decorator + + +@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") +def get_cosmos_tokenizer_comp8x8x8( + resolution: str, + chunk_duration: int, +) -> omegaconf.dictconfig.DictConfig: + assert resolution in ["512", "720"] + + pixel_chunk_duration = chunk_duration + temporal_compression_factor = 8 + spatial_compression_factor = 8 + + return L(VideoJITTokenizer)( + name="cosmos_diffusion_tokenizer_comp8x8x8", + enc_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", + dec_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", + mean_std_fp="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", + latent_ch=16, + is_bf16=True, + pixel_chunk_duration=pixel_chunk_duration, + temporal_compression_factor=temporal_compression_factor, + spatial_compression_factor=spatial_compression_factor, + spatial_resolution=resolution, + ) diff --git a/cosmos_predict1/diffusion/training/config/config.py b/cosmos_predict1/diffusion/training/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..38e82f6f0f6e3b52dc66f960dc0ed019e7b381c1 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/config.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List + +import attrs + +from cosmos_predict1.diffusion.training.config.base.model import DefaultModelConfig +from cosmos_predict1.diffusion.training.config.text2world.registry import ( + register_configs as register_configs_text2world, +) +from cosmos_predict1.diffusion.training.config.video2world.registry import ( + register_configs as register_configs_video2world, +) +from cosmos_predict1.diffusion.training.config.video2world_action.registry import ( + register_configs as register_configs_video2world_action, +) +from cosmos_predict1.diffusion.training.config.video2world_instruction.registry import ( + register_configs as register_configs_video2world_instruction, +) +from cosmos_predict1.diffusion.training.models.model import DiffusionModel +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.trainer import Trainer + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_train": None}, + {"data_val": None}, + {"optimizer": "fusedadamw"}, + {"scheduler": "lambdalinear"}, + {"callbacks": None}, + {"net": None}, + {"conditioner": "add_fps_image_size_padding_mask"}, + {"fsdp": None}, + {"ema": "power"}, + {"vae": "vae1"}, + {"checkpoint": "pbss"}, + {"ckpt_klass": "fsdp"}, + # the list is with order, we need global experiment to be the last one + {"experiment": None}, + ] + ) + model_obj: LazyDict = L(DiffusionModel)( + config=PLACEHOLDER, + ) + + +def make_config(): + c = Config( + model=DefaultModelConfig(), + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + ) + + # Specifying values through instances of attrs + c.job.project = "cosmos_predict1" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = Trainer + c.trainer.max_iter = 400_000 + c.trainer.logging_iter = 10 + c.trainer.validation_iter = 100 + c.trainer.run_validation = False + c.trainer.callbacks = None + + c.checkpoint = None + + # Call this function to register config groups. 
+ register_configs_text2world() + register_configs_video2world() + register_configs_video2world_instruction() + register_configs_video2world_action() + + # experiment configs are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.text2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world_instruction", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world_action", reload=True) + + return c diff --git a/cosmos_predict1/diffusion/training/config/config_multiview.py b/cosmos_predict1/diffusion/training/config/config_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..fabf37c5643d0d4c1f1d1a8d90df7eacae511963 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/config_multiview.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List + +import attrs + +from cosmos_predict1.diffusion.training.config.base.model import MultiviewModelConfig +from cosmos_predict1.diffusion.training.config.text2world.registry import ( + register_configs as register_configs_text2world, +) +from cosmos_predict1.diffusion.training.config.text2world_multiview.registry import ( + register_configs as register_configs_text2world_multiview, +) +from cosmos_predict1.diffusion.training.config.video2world.registry import ( + register_configs as register_configs_video2world, +) +from cosmos_predict1.diffusion.training.config.video2world_multiview.registry import ( + register_configs as register_configs_video2world_multiview, +) +from cosmos_predict1.diffusion.training.models.model import DiffusionModel +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.trainer import Trainer + + +@attrs.define(slots=False) +class Config(config.Config): + # default config groups that will be used unless overwritten + # see config groups in registry.py + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_train": None}, + {"data_val": None}, + {"optimizer": "fusedadamw"}, + {"scheduler": "lambdalinear"}, + {"callbacks": None}, + {"net": None}, + {"conditioner": "add_fps_image_size_padding_mask"}, + {"fsdp": None}, + {"ema": "power"}, + {"vae": "vae1"}, + {"checkpoint": "pbss"}, + {"ckpt_klass": "fsdp"}, + # the list is with order, we need global experiment to be the last one + {"experiment": None}, + ] + ) + model_obj:
LazyDict = L(DiffusionModel)( + config=PLACEHOLDER, + ) + + +def make_config(): + c = Config( + model=MultiviewModelConfig(), + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + ) + + # Specifying values through instances of attrs + c.job.project = "cosmos_predict1" + c.job.group = "debug" + c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = Trainer + # c.trainer.straggler_detection.enabled = False + c.trainer.max_iter = 400_000 + c.trainer.logging_iter = 10 + c.trainer.validation_iter = 100 + c.trainer.run_validation = False + c.trainer.callbacks = None + + c.checkpoint = None + + # Call this function to register config groups. + register_configs_text2world() + register_configs_video2world() + register_configs_text2world_multiview() + register_configs_video2world_multiview() + + # experiment configs are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.text2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.text2world_multiview", reload=True) + import_all_modules_from_package("cosmos_predict1.diffusion.training.config.video2world_multiview", reload=True) + + return c diff --git a/cosmos_predict1/diffusion/training/config/text2world/experiment.py b/cosmos_predict1/diffusion/training/config/text2world/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..5e67f0585f515472185ef545b272376da27374df --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world/experiment.py @@ -0,0 +1,1020 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_video import Dataset +from cosmos_predict1.diffusion.training.models.model import FSDPDiffusionModel +from cosmos_predict1.diffusion.training.models.model_peft import PEFTVideoDiffusionModel +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +n_length = 15 +num_frames = 8 * n_length + 1 # 121 + +# HDVILA example +example_video_dataset_hdvila = L(Dataset)( + dataset_dir="datasets/hdvila", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_hdvila = L(DataLoader)( + dataset=example_video_dataset_hdvila, + sampler=L(get_sampler)(dataset=example_video_dataset_hdvila), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_hdvila = L(DataLoader)( + dataset=example_video_dataset_hdvila, + sampler=L(get_sampler)(dataset=example_video_dataset_hdvila), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +# Cosmos-NeMo-Assets example +example_video_dataset_cosmos_nemo_assets = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +# Cosmos-NeMo-Assets 480x848 example for lora +example_video_dataset_cosmos_nemo_assets_480_848 = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(480, 848), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +# Cosmos-NeMo-Assets examples with 
more affordable GPUs setup (4 GPUs or 40GB VRAM) +n_length_4gpu_80gb = 15 +num_frames_4gpu_80gb = 8 * n_length_4gpu_80gb + 1 # 121 +example_video_dataset_cosmos_nemo_assets_4gpu_80gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_80gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering the content aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_80gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_4gpu_80gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_8gpu_40gb = 4 +num_frames_8gpu_40gb = 8 * n_length_8gpu_40gb + 1 # 33 +example_video_dataset_cosmos_nemo_assets_8gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_8gpu_40gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_8gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_8gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_4gpu_40gb = 2 +num_frames_4gpu_40gb = 8 * n_length_4gpu_40gb + 1 # 17 +example_video_dataset_cosmos_nemo_assets_4gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_40gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering aspect ratio. 
+ start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) +dataloader_val_cosmos_nemo_assets_4gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + + +text2world_7b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_hdvila", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + grad_accum_iter=2, + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + loss_scale=10.0, + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_hdvila, + dataloader_val=dataloader_val_hdvila, + ) +) + + +text2world_14b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_14b_example_hdvila", + ), + optimizer=dict( + lr=2 ** (-16), + weight_decay=0.2, + betas=[0.9, 0.99], + eps=1e-11, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-14B-Text2World/model.pt", + load_training_state=False, + 
strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=8, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + loss_scale=10.0, + ema=dict( + enabled=True, + num=1, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=64, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + use_memory_save=True, + ), + adjust_video_noise=True, + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[90_000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1e-1], + ), + dataloader_train=dataloader_train_hdvila, + dataloader_val=dataloader_val_hdvila, + ) +) + +text2world_7b_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + 
model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets, + dataloader_val=dataloader_val_cosmos_nemo_assets, + ) +) + +text2world_7b_example_cosmos_nemo_assets_4gpu_80gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets_4gpu_80gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + use_memory_save=False, + ), + vae=dict( + pixel_chunk_duration=num_frames_4gpu_80gb, + spatial_resolution="384", + ), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_80gb, + dataloader_val=dataloader_val_cosmos_nemo_assets_4gpu_80gb, + ) +) + +text2world_7b_example_cosmos_nemo_assets_8gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets_8gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + 
async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + use_memory_save=False, + ), + vae=dict( + pixel_chunk_duration=num_frames_8gpu_40gb, + spatial_resolution="384", + ), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_8gpu_40gb, + dataloader_val=dataloader_val_cosmos_nemo_assets_8gpu_40gb, + ) +) + +text2world_7b_example_cosmos_nemo_assets_4gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_example_cosmos_nemo_assets_4gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + use_memory_save=False, + ), + vae=dict( + 
pixel_chunk_duration=num_frames_4gpu_40gb, + spatial_resolution="384", + ), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_40gb, + dataloader_val=dataloader_val_cosmos_nemo_assets_4gpu_40gb, + ) +) + + +text2world_14b_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_14b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_14b_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=2 ** (-16), + weight_decay=0.2, + betas=[0.9, 0.99], + eps=1e-11, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-14B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=16, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + loss_scale=10.0, + ema=dict( + enabled=True, + num=1, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=64, + sharding_strategy="hybrid", + ), + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=2.0, + rope_t_extrapolation_ratio=2.0, + rope_w_extrapolation_ratio=2.0, + extra_h_extrapolation_ratio=2.0, + extra_t_extrapolation_ratio=2.0, + extra_w_extrapolation_ratio=2.0, + use_memory_save=True, + ), + adjust_video_noise=True, + vae=dict(pixel_chunk_duration=num_frames), + conditioner=dict(text=dict(dropout_rate=0.0)), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[90_000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1e-1], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets, + dataloader_val=dataloader_val_cosmos_nemo_assets, + ) +) + +text2world_7b_lora_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "peft"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_7b_lora_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=1e-4, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=1000, + 
broadcast_via_filesystem=True, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, + ), + trainer=dict( + max_iter=5000, + distributed_parallelism="ddp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=False, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=4, + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1), + # Use 16x16x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=False, + net=dict( + in_channels=16, + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(PEFTVideoDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + scheduler=dict( + warm_up_steps=[0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_480_848, + dataloader_val=dataloader_val_cosmos_nemo_assets_480_848, + ) +) + + +def register_experiments(cs: ConfigStore) -> None: + # Register the experiments + for _item in [ + text2world_7b_example_hdvila, + text2world_14b_example_hdvila, + text2world_7b_example_cosmos_nemo_assets, + text2world_14b_example_cosmos_nemo_assets, + text2world_7b_example_cosmos_nemo_assets_4gpu_80gb, + text2world_7b_example_cosmos_nemo_assets_8gpu_40gb, + text2world_7b_example_cosmos_nemo_assets_4gpu_40gb, + text2world_7b_lora_example_cosmos_nemo_assets, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/text2world/registry.py b/cosmos_predict1/diffusion/training/config/text2world/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..4a292447ec51a1c9a710d9c8591803f41a59464c --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world/registry.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +from typing import Dict + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.checkpointer.peft_checkpointer import Checkpointer as PEFTCheckpointer +from cosmos_predict1.diffusion.checkpointers.ema_fsdp_checkpointer import CheckpointConfig, FSDPCheckpointer +from cosmos_predict1.diffusion.conditioner import VideoExtendConditioner +from cosmos_predict1.diffusion.config.base.conditioner import ( + FPSConfig, + ImageSizeConfig, + NumFramesConfig, + PaddingMaskConfig, + TextConfig, + VideoCondBoolConfig, +) +from cosmos_predict1.diffusion.training.conditioner import VideoConditioner +from cosmos_predict1.diffusion.training.config.base.optim import FusedAdamWConfig, LambdaLinearSchedulerConfig +from cosmos_predict1.diffusion.training.config.base.vae import get_cosmos_tokenizer_comp8x8x8 +from cosmos_predict1.diffusion.training.config.text2world.experiment import register_experiments +from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_predict1.utils.ema import PowerEMATracker +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +FSDP_CHECKPOINTER: Dict[str, str] = L(FSDPCheckpointer)() +PEFT_CHECKPOINTER: Dict[str, str] = L(PEFTCheckpointer)() +VideoExtendConditionerConfig: LazyDict = L(VideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), +) + + +VideoConditionerFpsSizePaddingConfig: LazyDict = L(VideoConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), +) + + +def register_conditioner(cs): + cs.store( + group="conditioner", + package="model.conditioner", + name="video_cond", + node=VideoExtendConditionerConfig, + ) + + cs.store( + group="conditioner", + package="model.conditioner", + name="add_fps_image_size_padding_mask", + node=VideoConditionerFpsSizePaddingConfig, + ) + + +def register_checkpoint_credential(cs): + CHECKPOINT_LOCAL = CheckpointConfig( + save_iter=1000, + load_path="", + load_training_state=False, + strict_resume=True, + ) + + cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) + + +def register_checkpointer(cs): + cs.store(group="ckpt_klass", package="checkpoint.type", name="fsdp", node=FSDP_CHECKPOINTER) + cs.store(group="ckpt_klass", package="checkpoint.type", name="peft", node=PEFT_CHECKPOINTER) + + +FADITV2Config: LazyDict = L(GeneralDIT)( + max_img_h=240, + max_img_w=240, + max_frames=128, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + model_channels=4096, + block_config="FA-CA-MLP", + spatial_attn_win_size=1, + temporal_attn_win_size=1, + num_blocks=28, + num_heads=32, + concat_padding_mask=True, + pos_emb_cls="rope3d", + pos_emb_learnable=False, + pos_emb_interpolation="crop", + block_x_format="THWBD", + additional_timestamp_channels=None, + affline_emb_norm=True, + use_adaln_lora=True, + adaln_lora_dim=256, + legacy_patch_emb=False, +) + +FADITV2_14B_Config = copy.deepcopy(FADITV2Config) +FADITV2_14B_Config.model_channels = 5120 +FADITV2_14B_Config.num_heads = 40 +FADITV2_14B_Config.num_blocks = 36 + + +def register_net(cs): + cs.store(group="net", package="model.net", name="faditv2_7b", node=FADITV2Config) + cs.store(group="net", package="model.net", 
name="faditv2_14b", node=FADITV2_14B_Config) + + +def register_vae(cs): + cs.store( + group="vae", + package="model.vae", + name="cosmos_diffusion_tokenizer_comp8x8x8", + node=get_cosmos_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), + ) + + +PowerEMAConfig: LazyDict = L(PowerEMATracker.initialize_multi_rank_ema)( + model=PLACEHOLDER, enabled=True, rate=0.10, num=3 +) + + +def register_ema(cs): + cs.store(group="ema", package="model.ema", name="power", node=PowerEMAConfig) + + +def register_optimizer(cs): + cs.store(group="optimizer", package="optimizer", name="fusedadamw", node=FusedAdamWConfig) + + +def register_scheduler(cs): + cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearSchedulerConfig) + + +def register_configs(): + cs = ConfigStore.instance() + + register_optimizer(cs) + register_scheduler(cs) + + register_net(cs) + register_conditioner(cs) + register_vae(cs) + + register_ema(cs) + + register_checkpoint_credential(cs) + register_checkpointer(cs) + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/text2world_multiview/experiment.py b/cosmos_predict1/diffusion/training/config/text2world_multiview/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9b9a213f89879cc8cfe8e5854743e86bbeec33 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world_multiview/experiment.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_multiview import Dataset +from cosmos_predict1.diffusion.training.models.model_multiview import FSDPDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +num_frames = 57 +num_views = 5 +view_keys = ["pinhole_front_left", "pinhole_front", "pinhole_front_right", "pinhole_side_left", "pinhole_side_right"] +example_multiview_dataset_waymo = L(Dataset)( + dataset_dir="datasets/waymo", + sequence_interval=1, + num_frames=num_frames, + view_keys=view_keys, + video_size=(480, 848), +) + + +text2world_multiview_7b_example_waymo = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + {"override /conditioner": "add_fps_image_size_padding_mask"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_text2world", + name="text2world_multiview_7b_example_waymo", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + # broadcast_via_filesystem=True, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=200, + hit_thres=5, + ), + # manual_gc=L(ManualGarbageCollection)(every_n=5), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + n_views=num_views, + # Use 16x16x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(MultiviewGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + n_views=num_views, + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 
steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + dataloader_val=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + text2world_multiview_7b_example_waymo, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/text2world_multiview/registry.py b/cosmos_predict1/diffusion/training/config/text2world_multiview/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c12886435fede2a3ab1a9d6f5fafda1c8ef4f019 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/text2world_multiview/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.text2world_multiview.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/video2world/experiment.py b/cosmos_predict1/diffusion/training/config/video2world/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..0817cf30119c6229cef2e8bfb0e3960117664986 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world/experiment.py @@ -0,0 +1,846 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_video import Dataset +from cosmos_predict1.diffusion.training.models.extend_model import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.models.model_peft import PEFTExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg import VideoExtendGeneralDIT +from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +n_length = 15 +num_frames = 8 * n_length + 1 # 121 + +# HDVILA example +example_video_dataset_hdvila = L(Dataset)( + dataset_dir="datasets/hdvila", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_hdvila = L(DataLoader)( + dataset=example_video_dataset_hdvila, + sampler=L(get_sampler)(dataset=example_video_dataset_hdvila), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +# Cosmos-NeMo-Assets example +example_video_dataset_cosmos_nemo_assets = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +# Cosmos-NeMo-Assets examples with more affordable GPUs setup (4 GPUs or 40GB VRAM) +n_length_4gpu_80gb = 15 +num_frames_4gpu_80gb = 8 * n_length_4gpu_80gb + 1 # 121 +example_video_dataset_cosmos_nemo_assets_4gpu_80gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_80gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering the content aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_80gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_80gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_8gpu_40gb = 3 +num_frames_8gpu_40gb = 8 * n_length_8gpu_40gb + 1 # 25 +example_video_dataset_cosmos_nemo_assets_8gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_8gpu_40gb, + video_size=(384, 384), # a low-res example for lower VRAM utilization without considering aspect ratio. 
+ start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_8gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_8gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=8, + pin_memory=True, +) + +n_length_4gpu_40gb = 3 +num_frames_4gpu_40gb = 8 * n_length_4gpu_40gb + 1 # 25 +example_video_dataset_cosmos_nemo_assets_4gpu_40gb = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames_4gpu_40gb, + video_size=(192, 192), # a low-res example for lower VRAM utilization without considering aspect ratio. + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_4gpu_40gb = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_4gpu_40gb), + batch_size=1, + drop_last=True, + num_workers=0, + pin_memory=True, +) + +# Cosmos-NeMo-Assets 480x848 example for lora +example_video_dataset_cosmos_nemo_assets_480_848 = L(Dataset)( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=num_frames, + video_size=(480, 848), + start_frame_interval=1, +) + +dataloader_train_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +dataloader_val_cosmos_nemo_assets_480_848 = L(DataLoader)( + dataset=example_video_dataset_cosmos_nemo_assets_480_848, + sampler=L(get_sampler)(dataset=example_video_dataset_cosmos_nemo_assets_480_848), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + +video2world_7b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_hdvila", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + 
adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_hdvila, + ) +) + + +video2world_7b_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + 
f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets, + ) +) + +video2world_7b_example_cosmos_nemo_assets_4gpu_80gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets_4gpu_80gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict( + pixel_chunk_duration=num_frames_4gpu_80gb, + spatial_resolution="384", + ), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_80gb, + ) +) + +video2world_7b_example_cosmos_nemo_assets_8gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets_8gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + 
load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 48, # Latent height dim + 48, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict( + pixel_chunk_duration=num_frames_8gpu_40gb, + spatial_resolution="384", + ), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_8gpu_40gb, + ) +) + +video2world_7b_example_cosmos_nemo_assets_4gpu_40gb = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_example_cosmos_nemo_assets_4gpu_40gb", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + 
model=dict( + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 24, # Latent height dim + 24, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict( + pixel_chunk_duration=num_frames_4gpu_40gb, + spatial_resolution="192", + ), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_4gpu_40gb, + ) +) + +video2world_7b_lora_example_cosmos_nemo_assets = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "peft"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_7b_lora_example_cosmos_nemo_assets", + ), + optimizer=dict( + lr=1e-4, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=1000, + broadcast_via_filesystem=True, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + async_saving=False, # set to False to save memory + ), + trainer=dict( + max_iter=5000, + distributed_parallelism="ddp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=False, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=4, + ), + model=dict( + peft_control=get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1), + latent_shape=[ + 16, + 16, + 88, + 160, + ], + loss_reduce="mean", + ema=dict( + enabled=False, # turn off to save memory + ), + fsdp_enabled=False, + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + 
normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(PEFTExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + scheduler=dict( + warm_up_steps=[0], + ), + dataloader_train=dataloader_train_cosmos_nemo_assets_480_848, + dataloader_val=dataloader_val_cosmos_nemo_assets_480_848, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_7b_example_hdvila, + video2world_7b_example_cosmos_nemo_assets, + video2world_7b_example_cosmos_nemo_assets_4gpu_80gb, + video2world_7b_example_cosmos_nemo_assets_8gpu_40gb, + video2world_7b_example_cosmos_nemo_assets_4gpu_40gb, + video2world_7b_lora_example_cosmos_nemo_assets, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world/registry.py b/cosmos_predict1/diffusion/training/config/video2world/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5b49f93e8b6bb932a7755e5bf8585104916e27 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.video2world.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/video2world_action/experiment.py b/cosmos_predict1/diffusion/training/config/video2world_action/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..522fb093ca4a8dcde618aaddb0e8a821634c1a36 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_action/experiment.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
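The LoRA post-training experiment above uses get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1), which presumably attaches rank-8 adapters to the q and v projections of the full-attention and cross-attention layers in the first 28 blocks. As a reminder of what such an adapter computes, here is a generic, self-contained PyTorch sketch of a LoRA-wrapped linear layer; it illustrates the technique only and is not the repo's PEFT implementation:

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Generic LoRA illustration: y = W x + scale * B(A(x)), with rank r << d."""

    def __init__(self, base: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                # pretrained weight stays frozen
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)         # adapter starts as a no-op
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))


# Example: a hypothetical 4096-wide q-projection wrapped with a rank-8, scale-1 adapter.
q_proj = nn.Linear(4096, 4096, bias=False)
q_proj_lora = LoRALinear(q_proj, rank=8, scale=1.0)
out = q_proj_lora(torch.randn(2, 4096))            # shape (2, 4096)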
+ +import os + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D +from cosmos_predict1.diffusion.training.models.extend_model import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_action import ActionConditionalVideoExtendGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +cs = ConfigStore.instance() +base_path = "datasets/bridge/" +train_annotation_path = os.path.join(base_path, "annotation/train") +val_annotation_path = os.path.join(base_path, "annotation/val") +test_annotation_path = os.path.join(base_path, "annotation/test") + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +bridge_train_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_action=True, + load_t5_embeddings=False, +) + +bridge_val_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="val", + load_action=True, + load_t5_embeddings=False, +) + + +dataloader_train = L(DataLoader)( + dataset=bridge_train_dataset, + sampler=L(get_sampler)(dataset=bridge_train_dataset), + batch_size=8, + drop_last=True, + pin_memory=True, + num_workers=8, +) +dataloader_val = L(DataLoader)( + dataset=bridge_val_dataset, + sampler=L(get_sampler)(dataset=bridge_val_dataset), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + + +video2world_action_bridge_2frames = LazyDict( # This experiment is used to verify the expanded config is the same as BASE002_101_512N_FSDP_LR-143_VideoImage_1-1 + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "action_conditional_video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world_action", + name="video2world_action_bridge_2frames", + ), + optimizer=dict( + lr=4e-4, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=500, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2_000, + 
distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + # Use 16x2x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 2, # Latent temporal dim + 32, # Latent height dim + 40, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(ActionConditionalVideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + first_random_n_num_condition_t_max=1, + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + ) + ), + # Use Image VAE for training + vae=dict(pixel_chunk_duration=1), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train, + dataloader_val=dataloader_val, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_action_bridge_2frames, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_action/registry.py b/cosmos_predict1/diffusion/training/config/video2world_action/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c0ba551a3297c923ca69e6a0dcdf2c43f391b5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_action/registry.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
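The latent_shape [16, 2, 32, 40] in the action-conditional experiment above is consistent with the bridge dataloader settings (2 frames at 256x320), assuming 8x spatial compression and, since pixel_chunk_duration=1 selects the image VAE, one latent frame per RGB frame:

# Sanity check for latent_shape = [16, 2, 32, 40] (latent channels, T, H, W).
num_frames, height, width = 2, 256, 320   # from bridge_train_dataset / video_size above
latent_t = num_frames                     # image VAE: one latent frame per input frame (assumed)
latent_h, latent_w = height // 8, width // 8
assert (latent_t, latent_h, latent_w) == (2, 32, 40)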
+ +from dataclasses import dataclass +from typing import Dict, Optional + +import attrs +import torch +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.conditioner import VideoExtendCondition, VideoExtendConditioner +from cosmos_predict1.diffusion.config.base.conditioner import ( + FPSConfig, + ImageSizeConfig, + NumFramesConfig, + PaddingMaskConfig, + ReMapkey, + TextConfig, + VideoCondBoolConfig, +) +from cosmos_predict1.diffusion.training.config.video2world_action.experiment import register_experiments +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@dataclass +class ActionConditionalVideoExtendCondition(VideoExtendCondition): + action: Optional[torch.Tensor] = None + + +class ActionConditionalVideoExtendConditioner(VideoExtendConditioner): + def forward( + self, + batch: Dict, + override_dropout_rate: Optional[Dict[str, float]] = None, + ) -> ActionConditionalVideoExtendCondition: + output = super()._forward(batch, override_dropout_rate) + assert "action" in batch, "ActionConditionalVideoExtendConditioner requires 'action' in batch" + output["action"] = batch["action"] + return ActionConditionalVideoExtendCondition(**output) + + +@attrs.define(slots=False) +class ActionConfig: + """ + Remap the key from the input dictionary to the output dictionary. For `action`. + """ + + obj: LazyDict = L(ReMapkey)(output_key="action", dtype=None) + dropout_rate: float = 0.0 + input_key: str = "action" + + +ActionConditionalVideoExtendConditionerConfig: LazyDict = L(ActionConditionalVideoExtendConditioner)( + text=TextConfig(), + fps=FPSConfig(), + num_frames=NumFramesConfig(), + image_size=ImageSizeConfig(), + padding_mask=PaddingMaskConfig(), + video_cond_bool=VideoCondBoolConfig(), + action=ActionConfig(), +) + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) + + cs.store( + group="conditioner", + package="model.conditioner", + name="action_conditional_video_cond", + node=ActionConditionalVideoExtendConditionerConfig, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_instruction/experiment.py b/cosmos_predict1/diffusion/training/config/video2world_instruction/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..f33bc5cd87bec43be6ea668814f6c60b5ab4b645 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_instruction/experiment.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
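The registry above is the template for adding a new conditioning input: extend the condition dataclass with the new field, subclass the conditioner so it copies the key out of the batch, describe the key with a ReMapkey-based config, and store the assembled conditioner under a new name. A hedged sketch of the same pattern for a hypothetical extra key; every name introduced below is illustrative and not part of the repo:

from dataclasses import dataclass
from typing import Dict, Optional

import torch

from cosmos_predict1.diffusion.conditioner import VideoExtendCondition, VideoExtendConditioner


@dataclass
class CameraPoseVideoExtendCondition(VideoExtendCondition):  # hypothetical condition type
    camera_pose: Optional[torch.Tensor] = None


class CameraPoseVideoExtendConditioner(VideoExtendConditioner):  # hypothetical conditioner
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> CameraPoseVideoExtendCondition:
        output = super()._forward(batch, override_dropout_rate)
        assert "camera_pose" in batch, "this conditioner requires 'camera_pose' in batch"
        output["camera_pose"] = batch["camera_pose"]
        return CameraPoseVideoExtendCondition(**output)

# The remaining steps mirror ActionConfig above: wrap the key with a ReMapkey-based config field,
# add it to an L(...)-built conditioner, and cs.store(...) it under a new conditioner name.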
+ +import os + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D +from cosmos_predict1.diffusion.training.models.extend_model import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg import VideoExtendGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +cs = ConfigStore.instance() +base_path = "datasets/bridge/" +train_annotation_path = os.path.join(base_path, "annotation/train") +val_annotation_path = os.path.join(base_path, "annotation/val") +test_annotation_path = os.path.join(base_path, "annotation/test") + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +bridge_train_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=57, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_action=False, + load_t5_embeddings=True, +) + +bridge_val_dataset = L(Dataset_3D)( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=base_path, + sequence_interval=1, + num_frames=57, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="val", + load_action=False, + load_t5_embeddings=True, +) + + +dataloader_train = L(DataLoader)( + dataset=bridge_train_dataset, + sampler=L(get_sampler)(dataset=bridge_train_dataset), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) +dataloader_val = L(DataLoader)( + dataset=bridge_val_dataset, + sampler=L(get_sampler)(dataset=bridge_val_dataset), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, +) + + +video2world_instruction_bridge_57frames = LazyDict( # This experiment is used to verify the expanded config is the same as BASE002_101_512N_FSDP_LR-143_VideoImage_1-1 + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world_instruction", + name="video2world_instruction_bridge_57frames", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=500, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-Video2World/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + 
max_iter=2_000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + # Use 16x8x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 8, # Latent temporal dim + 32, # Latent height dim + 40, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=False, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + # Use Image VAE for training + vae=dict(pixel_chunk_duration=57), + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + first_random_n_num_condition_t_max=1, + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + ) + ), + ), + # using the video extend model for training + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train, + dataloader_val=dataloader_val, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_instruction_bridge_57frames, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_instruction/registry.py b/cosmos_predict1/diffusion/training/config/video2world_instruction/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bea45494ca7826258e9f9629dc480c52cfbbdec5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_instruction/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
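The temporal latent dimension of 8 used above follows from the causal tokenizer's 8x temporal compression, assuming the usual relation that the first frame produces one latent frame and each further group of 8 frames adds one more; the same relation explains the value 16 used with the 121-frame examples earlier in this diff:

# Assumed relation for the 8x causal temporal compression: latent_t = (num_frames - 1) // 8 + 1
def latent_t(num_frames: int) -> int:
    return (num_frames - 1) // 8 + 1

assert latent_t(57) == 8     # video2world_instruction_bridge_57frames -> latent_shape[1] == 8
assert latent_t(121) == 16   # 121-frame video2world examples -> latent_shape[1] == 16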
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.video2world_instruction.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/video2world_multiview/experiment.py b/cosmos_predict1/diffusion/training/config/video2world_multiview/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d6e2245873867f455e38654b2edab069b034bf --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_multiview/experiment.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_multiview import Dataset +from cosmos_predict1.diffusion.training.models.extend_model_multiview import FSDPExtendDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg_multiview import VideoExtendMultiviewGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +num_frames = 57 +num_views = 5 +view_keys = ["pinhole_front_left", "pinhole_front", "pinhole_front_right", "pinhole_side_left", "pinhole_side_right"] +example_multiview_dataset_waymo = L(Dataset)( + dataset_dir="datasets/waymo", + sequence_interval=1, + num_frames=num_frames, + view_keys=view_keys, + video_size=(480, 848), +) + + +video2world_multiview_7b_example_waymo = LazyDict( + dict( + defaults=[ + {"override /net": "faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_video2world", + name="video2world_multiview_7b_example_waymo", + ), + optimizer=dict( + lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + # broadcast_via_filesystem=True, + broadcast_via_filesystem=False, + 
load_path="checkpoints/Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=200, + hit_thres=5, + ), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + n_views=num_views, + # Use 16x16x32x40 latent shape for training + latent_shape=[ + 16, # Latent channel dim + 16, # Latent temporal dim + 88, # Latent height dim + 160, # Latent width dim + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendMultiviewGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + n_views=num_views, + ), + conditioner=dict( + video_cond_bool=dict( + condition_location="first_random_n", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ) + ), + vae=dict(pixel_chunk_duration=num_frames), + ), + model_obj=L(FSDPExtendDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + dataloader_val=L(DataLoader)( + dataset=example_multiview_dataset_waymo, + sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), + batch_size=1, + drop_last=True, + pin_memory=True, + num_workers=8, + ), + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + video2world_multiview_7b_example_waymo, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/video2world_multiview/registry.py b/cosmos_predict1/diffusion/training/config/video2world_multiview/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..b30c4cdc00f18675329364a4f64679578dbf8c97 --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/video2world_multiview/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.video2world_multiview.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/config/world_interpolator/experiment.py b/cosmos_predict1/diffusion/training/config/world_interpolator/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..92ad62cd053b2dba6daa0b02fe89b0b5c063e87a --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/world_interpolator/experiment.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore +from megatron.core import parallel_state +from torch.utils.data import DataLoader, DistributedSampler + +from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed +from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback +from cosmos_predict1.diffusion.training.datasets.dataset_video import Dataset +from cosmos_predict1.diffusion.training.models.interpolator import FSDPInterpolatorDiffusionModel +from cosmos_predict1.diffusion.training.networks.general_dit_lvg import VideoExtendGeneralDIT +from cosmos_predict1.utils import log +from cosmos_predict1.utils.callback import ProgressBarCallback +from cosmos_predict1.utils.callbacks.grad_clip import GradClip +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +def get_sampler(dataset): + return DistributedSampler( + dataset, + num_replicas=parallel_state.get_data_parallel_world_size(), + rank=parallel_state.get_data_parallel_rank(), + shuffle=True, + seed=0, + ) + + +cs = ConfigStore.instance() + +num_frames = 18 +example_video_dataset = L(Dataset)( + dataset_dir="datasets/hdvila", + sequence_interval=1, + num_frames=num_frames, + video_size=(720, 1280), + start_frame_interval=1, +) + +dataloader_train = L(DataLoader)( + dataset=example_video_dataset, + sampler=L(get_sampler)(dataset=example_video_dataset), + batch_size=1, + drop_last=True, +) +dataloader_val = L(DataLoader)( + dataset=example_video_dataset, + sampler=L(get_sampler)(dataset=example_video_dataset), + batch_size=1, + drop_last=True, +) + + +world_interpolator_7b_example_hdvila = LazyDict( + dict( + defaults=[ + {"override /net": 
"faditv2_7b"}, + {"override /conditioner": "video_cond"}, + {"override /ckpt_klass": "fsdp"}, + {"override /checkpoint": "local"}, + {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, + "_self_", + ], + job=dict( + project="posttraining", + group="diffusion_world_interpolator", + name="world_interpolator_7b_example_hdvila", + ), + optimizer=dict( + # lr=2 ** (-14.3), # 2**(-14.3) approx 5e-5 + lr=0.0, + weight_decay=0.1, + betas=[0.9, 0.99], + eps=1e-10, + ), + checkpoint=dict( + save_iter=200, + # save_iter=1, + broadcast_via_filesystem=False, + load_path="checkpoints/Cosmos-Predict1-7B-WorldInterpolator/model.pt", + load_training_state=False, + strict_resume=False, + keys_not_to_resume=[], + ), + trainer=dict( + max_iter=2000, + # max_iter=2, + distributed_parallelism="fsdp", + logging_iter=200, + callbacks=dict( + grad_clip=L(GradClip)( + model_key="model", + fsdp_enabled=True, + ), + low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), + iter_speed=L(IterSpeed)( + every_n=10, + hit_thres=0, + ), + progress_bar=L(ProgressBarCallback)(), + ), + ), + model_parallel=dict( + sequence_parallel=False, + tensor_model_parallel_size=1, + context_parallel_size=1, + ), + model=dict( + latent_shape=[ + 16, + 4, + 88, + 160, + ], + loss_reduce="mean", + ema=dict( + enabled=True, + ), + fsdp_enabled=True, + fsdp=dict( + policy="block", + # checkpoint=False, + checkpoint=True, + min_num_params=1024, + sharding_group_size=32, + sharding_strategy="hybrid", + ), + net=L(VideoExtendGeneralDIT)( + rope_h_extrapolation_ratio=1, + rope_w_extrapolation_ratio=1, + rope_t_extrapolation_ratio=2, + ), + adjust_video_noise=True, + context_parallel_size=1, + num_latents_to_drop=1, + conditioner=dict( + video_cond_bool=dict( + condition_location="first_and_last_1", + cfg_unconditional_type="zero_condition_region_condition_mask", + apply_corruption_to_condition_region="noise_with_sigma", + condition_on_augment_sigma=False, + dropout_rate=0.0, # No dropout + first_random_n_num_condition_t_max=2, + normalize_condition_latent=False, + # Let the augment sigma mostly fall in the range of 0 to 0.3 + augment_sigma_sample_p_mean=-3.0, + augment_sigma_sample_p_std=2.0, + augment_sigma_sample_multiplier=1.0, + ), + text=dict( + dropout_rate=0.5, + ), + ), + vae=dict(pixel_chunk_duration=9), # 9 frames per chunk for video vae (18 frames / 2 chunks = 9) + ), + model_obj=L(FSDPInterpolatorDiffusionModel)( + config=PLACEHOLDER, + fsdp_checkpointer=PLACEHOLDER, + ), + # warming up for first 2500 steps~(when resume from 310000) + scheduler=dict( + warm_up_steps=[2500], + cycle_lengths=[10000000000000], + f_start=[1.0e-6], + f_max=[1.0], + f_min=[1.0], + ), + dataloader_train=dataloader_train, + ) +) + + +def register_experiments(cs): + # Register the experiments + for _item in [ + world_interpolator_7b_example_hdvila, + ]: + experiment_name = _item["job"]["name"] + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) diff --git a/cosmos_predict1/diffusion/training/config/world_interpolator/registry.py b/cosmos_predict1/diffusion/training/config/world_interpolator/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..601f2e4da4bf89586d26f28e6cb1bd1103ca662b --- /dev/null +++ b/cosmos_predict1/diffusion/training/config/world_interpolator/registry.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.diffusion.training.config.world_interpolator.experiment import register_experiments + + +def register_configs(): + cs = ConfigStore.instance() + + register_experiments(cs) diff --git a/cosmos_predict1/diffusion/training/context_parallel.py b/cosmos_predict1/diffusion/training/context_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..4758db33f1d63546d01be917fa19a821ed4bc1f7 --- /dev/null +++ b/cosmos_predict1/diffusion/training/context_parallel.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor +from torch.distributed import ProcessGroup, all_gather, get_process_group_ranks, get_world_size + + +def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Split input tensor along the sequence dimension for checkpoint parallelism. + + This function divides the input tensor into equal parts along the specified + sequence dimension, based on the number of ranks in the checkpoint parallelism group. + It then selects the part corresponding to the current rank. + + Args: + x: Input tensor to be split. + seq_dim: The dimension along which to split the input (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A slice of the input tensor corresponding to the current rank. + + Raises: + AssertionError: If the sequence dimension is not divisible by the number of ranks. + """ + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + + assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" + x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) + seq_idx = torch.tensor([cp_group.rank()], device=x.device) + x = x.index_select(seq_dim, seq_idx) + # Note that the new sequence length is the original sequence length / cp_size + x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + return x + + +def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Concatenate outputs from different ranks in the checkpoint parallelism group. 
+ + This function gathers tensors from all ranks in the checkpoint parallelism group + and concatenates them along the specified sequence dimension. + + Args: + x: Input tensor to be concatenated. + seq_dim: The dimension along which to concatenate the tensors (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A tensor that is the concatenation of tensors from all ranks in the cp_group. + + Raises: + RuntimeError: If the gather operation fails. + """ + # Get the world size (number of processes in the group) + world_size = get_world_size(cp_group) + + # Create a list to store tensors from all ranks + gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] + + # Gather tensors from all ranks + try: + all_gather(gathered_tensors, x, group=cp_group) + except RuntimeError as e: + raise RuntimeError(f"Failed to gather tensors: {e}") + + # Concatenate the gathered tensors along the specified dimension + return torch.cat(gathered_tensors, dim=seq_dim) + + +def cat_outputs_cp_with_grad(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: + """ + Concatenate outputs from different ranks in the context parallelism group. + + This function gathers tensors from all ranks in the checkpoint parallelism group + and concatenates them along the specified sequence dimension. + + It retains computational graph locally for each rank by replacing the portion of the tensor with original output. + + Args: + x: Input tensor to be concatenated. + seq_dim: The dimension along which to concatenate the tensors (sequence dimension). + cp_group: The process group for checkpoint parallelism. + + Returns: + A tensor that is the concatenation of tensors from all ranks in the cp_group. + + Raises: + RuntimeError: If the gather operation fails. + """ + # Get the world size (number of processes in the group) + cp_size = cp_group.size() + assert cp_size > 0, "cp_size should be greater than 0" + + # Create a list to store tensors from all ranks + gathered_tensors = [torch.zeros_like(x) for _ in range(cp_size)] + + # Gather tensors from all ranks + try: + all_gather(gathered_tensors, x, group=cp_group) + except RuntimeError as e: + raise RuntimeError(f"Failed to gather tensors: {e}") + + rank = cp_group.rank() + gathered_tensors[rank] = x + # Concatenate the gathered tensors along the specified dimension + return torch.cat(gathered_tensors, dim=seq_dim) diff --git a/cosmos_predict1/diffusion/training/datasets/data_sources/item_dataset.py b/cosmos_predict1/diffusion/training/datasets/data_sources/item_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4dcd91920a9ab1364af8a2a482495cb4bec0f9 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/data_sources/item_dataset.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
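split_inputs_cp and cat_outputs_cp above shard a tensor along its sequence dimension across the context-parallel group and gather the shards back afterwards. A single-process illustration of that round trip, with torch.chunk standing in for the per-rank selection and torch.cat standing in for the all_gather (no process group is involved here):

import torch

cp_size, seq_dim = 4, 1
x = torch.randn(2, 16, 64)                            # (batch, sequence, hidden)
assert x.shape[seq_dim] % cp_size == 0                # same divisibility check as split_inputs_cp

shards = list(torch.chunk(x, cp_size, dim=seq_dim))   # what each rank would hold locally
local = shards[1]                                     # e.g. the shard owned by rank 1
print(local.shape)                                    # torch.Size([2, 4, 64])

reassembled = torch.cat(shards, dim=seq_dim)          # what cat_outputs_cp reconstructs
assert torch.equal(reassembled, x)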
+ +import dataclasses + + +@dataclasses.dataclass +class ItemDatasetConfig: + path: str + length: int diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_3D.py b/cosmos_predict1/diffusion/training/datasets/dataset_3D.py new file mode 100644 index 0000000000000000000000000000000000000000..e561400a34ba4e504ee717a1edabdf387108bf9a --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_3D.py @@ -0,0 +1,420 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import json +import os +import pickle +import random +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import imageio +import numpy as np +import torch +from decord import VideoReader, cpu +from einops import rearrange +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import ( + Resize_Preprocess, + ToTensorVideo, + euler2rotm, + rotm2euler, +) + + +class Dataset_3D(Dataset): + def __init__( + self, + train_annotation_path, + val_annotation_path, + test_annotation_path, + video_path, + sequence_interval, + num_frames, + cam_ids, + accumulate_action, + video_size, + val_start_frame_interval, + debug=False, + normalize=False, + pre_encode=False, + do_evaluate=False, + load_t5_embeddings=False, + load_action=True, + mode="train", + ): + """Dataset class for loading 3D robot action-conditional data. + + This dataset loads robot trajectories consisting of RGB video frames, robot states (arm positions and gripper states), + and computes relative actions between consecutive frames. + + Args: + train_annotation_path (str): Path to training annotation files + val_annotation_path (str): Path to validation annotation files + test_annotation_path (str): Path to test annotation files + video_path (str): Base path to video files + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + cam_ids (list): List of camera IDs to sample from + accumulate_action (bool): Whether to accumulate actions relative to first frame + video_size (list): Target size [H,W] for video frames + val_start_frame_interval (int): Frame sampling interval for validation/test + debug (bool, optional): If True, only loads subset of data. Defaults to False. + normalize (bool, optional): Whether to normalize video frames. Defaults to False. + pre_encode (bool, optional): Whether to pre-encode video frames. Defaults to False. + do_evaluate (bool, optional): Whether in evaluation mode. Defaults to False. 
+ load_t5_embeddings (bool, optional): Whether to load T5 embeddings. Defaults to False. + load_action (bool, optional): Whether to load actions. Defaults to True. + mode (str, optional): Dataset mode - 'train', 'val' or 'test'. Defaults to 'train'. + + The dataset loads robot trajectories and computes: + - RGB video frames from specified camera views + - Robot arm states (xyz position + euler angles) + - Gripper states (binary open/closed) + - Relative actions between consecutive frames + + Actions are computed as relative transforms between frames: + - Translation: xyz offset in previous frame's coordinate frame + - Rotation: euler angles of relative rotation + - Gripper: binary gripper state + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - action: Action tensor [T-1,7] + - video_name: Dict with episode/frame metadata + - latent: Pre-encoded video features if pre_encode=True + """ + + super().__init__() + if mode == "train": + self.data_path = train_annotation_path + self.start_frame_interval = 1 + elif mode == "val": + self.data_path = val_annotation_path + self.start_frame_interval = val_start_frame_interval + elif mode == "test": + self.data_path = test_annotation_path + self.start_frame_interval = val_start_frame_interval + self.video_path = video_path + self.sequence_interval = sequence_interval + self.mode = mode + self.sequence_length = num_frames + self.normalize = normalize + self.pre_encode = pre_encode + self.load_t5_embeddings = load_t5_embeddings + self.load_action = load_action + + self.cam_ids = cam_ids + self.accumulate_action = accumulate_action + + self.action_dim = 7 # ee xyz (3) + ee euler (3) + gripper(1) + self.c_act_scaler = [20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 1.0] + self.c_act_scaler = np.array(self.c_act_scaler, dtype=float) + self.ann_files = self._init_anns(self.data_path) + + self.samples = self._init_sequences(self.ann_files) + + self.samples = sorted(self.samples, key=lambda x: (x["ann_file"], x["frame_ids"][0])) + if debug and not do_evaluate: + self.samples = self.samples[0:10] + self.wrong_number = 0 + self.transform = T.Compose([T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]) + self.training = False + self.preprocess = T.Compose( + [ + ToTensorVideo(), + Resize_Preprocess(tuple(video_size)), # 288 512 + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + self.not_norm_preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + def __str__(self): + return f"{len(self.ann_files)} samples from {self.data_path}" + + def _init_anns(self, data_dir): + ann_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".json")] + return ann_files + + def _init_sequences(self, ann_files): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_ann_file = { + executor.submit(self._load_and_process_ann_file, ann_file): ann_file for ann_file in ann_files + } + for future in tqdm(as_completed(future_to_ann_file), total=len(ann_files)): + samples.extend(future.result()) + return samples + + def _load_and_process_ann_file(self, ann_file): + samples = [] + with open(ann_file, "r") as f: + ann = json.load(f) + + n_frames = len(ann["state"]) + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["ann_file"] = ann_file + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == 
self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + return frame_data + + def _get_frames(self, label, frame_ids, cam_id, pre_encode): + if pre_encode: + raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") + else: + video_path = label["videos"][cam_id]["video_path"] + video_path = os.path.join(self.video_path, video_path) + frames = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + + def printvideo(videos, filename): + t_videos = rearrange(videos, "f c h w -> f h w c") + t_videos = ( + ((t_videos / 2.0 + 0.5).clamp(0, 1) * 255).detach().to(dtype=torch.uint8).cpu().contiguous().numpy() + ) + print(t_videos.shape) + writer = imageio.get_writer(filename, fps=4) # fps 是帧率 + for frame in t_videos: + writer.append_data(frame) # 1 4 13 23 # fp16 24 76 456 688 + + if self.normalize: + frames = self.preprocess(frames) + else: + frames = self.not_norm_preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames + + def _get_obs(self, label, frame_ids, cam_id, pre_encode): + if cam_id is None: + temp_cam_id = random.choice(self.cam_ids) + else: + temp_cam_id = cam_id + frames = self._get_frames(label, frame_ids, cam_id=temp_cam_id, pre_encode=pre_encode) + return frames, temp_cam_id + + def _get_robot_states(self, label, frame_ids): + all_states = np.array(label["state"]) + all_cont_gripper_states = np.array(label["continuous_gripper_state"]) + states = all_states[frame_ids] + cont_gripper_states = all_cont_gripper_states[frame_ids] + arm_states = states[:, :6] + assert arm_states.shape[0] == self.sequence_length + assert cont_gripper_states.shape[0] == self.sequence_length + return arm_states, cont_gripper_states + + def _get_all_robot_states(self, label, frame_ids): + all_states = np.array(label["state"]) + all_cont_gripper_states = np.array(label["continuous_gripper_state"]) + states = all_states[frame_ids] + cont_gripper_states = all_cont_gripper_states[frame_ids] + arm_states = states[:, :6] + return arm_states, cont_gripper_states + + def _get_all_actions(self, arm_states, gripper_states, accumulate_action): + action_num = arm_states.shape[0] - 1 + action = np.zeros((action_num, self.action_dim)) + if accumulate_action: + first_xyz = arm_states[0, 0:3] + first_rpy = arm_states[0, 3:6] + first_rotm = euler2rotm(first_rpy) + for k in range(1, action_num + 1): + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz) + rel_rotm = first_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + else: + for k in range(1, action_num + 1): + prev_xyz = arm_states[k - 1, 0:3] + prev_rpy = arm_states[k - 1, 3:6] + prev_rotm = euler2rotm(prev_rpy) + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + 
curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz) + rel_rotm = prev_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + return torch.from_numpy(action) # (l - 1, act_dim) + + def _get_actions(self, arm_states, gripper_states, accumulate_action): + action = np.zeros((self.sequence_length - 1, self.action_dim)) + if accumulate_action: + first_xyz = arm_states[0, 0:3] + first_rpy = arm_states[0, 3:6] + first_rotm = euler2rotm(first_rpy) + for k in range(1, self.sequence_length): + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz) + rel_rotm = first_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + else: + for k in range(1, self.sequence_length): + prev_xyz = arm_states[k - 1, 0:3] + prev_rpy = arm_states[k - 1, 3:6] + prev_rotm = euler2rotm(prev_rpy) + curr_xyz = arm_states[k, 0:3] + curr_rpy = arm_states[k, 3:6] + curr_gripper = gripper_states[k] + curr_rotm = euler2rotm(curr_rpy) + rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz) + rel_rotm = prev_rotm.T @ curr_rotm + rel_rpy = rotm2euler(rel_rotm) + action[k - 1, 0:3] = rel_xyz + action[k - 1, 3:6] = rel_rpy + action[k - 1, 6] = curr_gripper + return torch.from_numpy(action) # (l - 1, act_dim) + + def __getitem__(self, index, cam_id=None, return_video=False): + if self.mode != "train": + np.random.seed(index) + random.seed(index) + + try: + sample = self.samples[index] + ann_file = sample["ann_file"] + frame_ids = sample["frame_ids"] + with open(ann_file, "r") as f: + label = json.load(f) + arm_states, gripper_states = self._get_robot_states(label, frame_ids) + actions = self._get_actions(arm_states, gripper_states, self.accumulate_action) + actions *= self.c_act_scaler + + data = dict() + if self.load_action: + data["action"] = actions.float() + + if self.pre_encode: + raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") + else: + video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video.to(dtype=torch.uint8) + + data["annotation_file"] = ann_file + + # NOTE: __key__ is used to uniquely identify the sample, required for callback functions + if "episode_id" in label: + data["__key__"] = label["episode_id"] + else: + data["__key__"] = label["original_path"] + + # Just add these to fit the interface + if self.load_t5_embeddings: + t5_embedding_path = ann_file.replace(".json", ".pickle") + with open(t5_embedding_path, "rb") as f: + data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0]) + else: + data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16) + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) + data["fps"] = 4 + data["image_size"] = 256 * torch.ones(4) # TODO: Does this matter? + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 256, 256) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." 
+ ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset_3D( + train_annotation_path="datasets/bridge/annotation/train", + val_annotation_path="datasets/bridge/annotation/val", + test_annotation_path="datasets/bridge/annotation/test", + video_path="datasets/bridge/", + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_t5_embeddings=True, + ) + + indices = [0, 13, 200, -1] + for idx in indices: + print( + ( + f"{idx=} " + f"{dataset[idx]['video'].sum()=}\n" + f"{dataset[idx]['video'].shape=}\n" + f"{dataset[idx]['video_name']=}\n" + f"{dataset[idx]['action'].sum()=}\n" + "---" + ) + ) + + from IPython import embed + + embed() diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_3D_binary.py b/cosmos_predict1/diffusion/training/datasets/dataset_3D_binary.py new file mode 100644 index 0000000000000000000000000000000000000000..c760c92c7f6531335c6194cace332c5664ab94c0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_3D_binary.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import json +import pickle +import random +import traceback +import warnings + +import numpy as np +import torch + +from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D +from cosmos_predict1.utils import log + + +class Dataset_3DBinary(Dataset_3D): + def __init__( + self, + train_annotation_path, + val_annotation_path, + test_annotation_path, + video_path, + sequence_interval, + num_frames, + cam_ids, + accumulate_action, + video_size, + val_start_frame_interval, + debug=False, + normalize=False, + pre_encode=False, + do_evaluate=False, + load_t5_embeddings=False, + load_action=True, + mode="train", + ): + """Dataset class for loading 3D robot action-conditional data. + + This dataset loads robot trajectories consisting of RGB video frames, robot states + (arm positions and binary gripper states), and computes relative actions between + consecutive frames. 
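The "relative actions between consecutive frames" mentioned in these dataset docstrings are frame-to-frame transforms expressed in the previous end-effector frame, as computed in `_get_actions` with `accumulate_action=False`. A small sketch with toy arm states (illustrative values only; assumes the cosmos_predict1 package is importable):

import numpy as np

from cosmos_predict1.diffusion.training.datasets.dataset_utils import euler2rotm, rotm2euler

prev = np.array([0.30, 0.00, 0.20, 0.0, 0.0, 0.0])  # frame k-1: xyz (m) + euler rpy (rad)
curr = np.array([0.32, 0.01, 0.21, 0.0, 0.0, 0.1])  # frame k:   xyz (m) + euler rpy (rad)
gripper = 1.0  # gripper state of frame k

prev_rotm, curr_rotm = euler2rotm(prev[3:6]), euler2rotm(curr[3:6])
rel_xyz = prev_rotm.T @ (curr[0:3] - prev[0:3])  # translation in the previous frame's coordinates
rel_rpy = rotm2euler(prev_rotm.T @ curr_rotm)    # relative rotation as euler angles
action = np.concatenate([rel_xyz, rel_rpy, [gripper]])  # 7-dim action, later scaled by c_act_scaler
print(action)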
+ """ + + super().__init__( + train_annotation_path=train_annotation_path, + val_annotation_path=val_annotation_path, + test_annotation_path=test_annotation_path, + video_path=video_path, + sequence_interval=sequence_interval, + num_frames=num_frames, + cam_ids=cam_ids, + accumulate_action=accumulate_action, + video_size=video_size, + val_start_frame_interval=val_start_frame_interval, + debug=debug, + normalize=normalize, + pre_encode=pre_encode, + do_evaluate=do_evaluate, + load_t5_embeddings=load_t5_embeddings, + load_action=load_action, + mode=mode, + ) + + log.info("Dataset_3DBinary: in this dataset, we binarize the gripper state to 0 or 1.") + + def _get_json_action(self, label, frame_ids): + all_action = np.array(label["action"]) + actions = all_action[frame_ids[:-1]] + return torch.from_numpy(actions) + + def __getitem__(self, index, cam_id=None, return_video=False): + if self.mode != "train": + np.random.seed(index) + random.seed(index) + + try: + sample = self.samples[index] + ann_file = sample["ann_file"] + frame_ids = sample["frame_ids"] + with open(ann_file, "r") as f: + label = json.load(f) + arm_states, gripper_states = self._get_robot_states(label, frame_ids) + actions = self._get_actions(arm_states, gripper_states, self.accumulate_action) + actions *= self.c_act_scaler + + data = dict() + if self.load_action: + data["action"] = actions.float() + json_action = self._get_json_action(label, frame_ids).float() + json_action[:, :6] = data["action"][:, :6] + data["action"] = json_action + + if self.pre_encode: + raise NotImplementedError("Pre-encoded videos are not supported for this dataset.") + else: + video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video.to(dtype=torch.uint8) + + data["annotation_file"] = ann_file + + if "episode_id" in label: + data["__key__"] = label["episode_id"] + else: + data["__key__"] = label["original_path"] + + # Just add these to fit the interface + if self.load_t5_embeddings: + t5_embedding_path = ann_file.replace(".json", ".pickle") + with open(t5_embedding_path, "rb") as f: + data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0]) + else: + data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16) + data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) + data["fps"] = 4 + data["image_size"] = 256 * torch.ones(4) # TODO: Does this matter? + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 256, 256) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." 
+ ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset_3DBinary( + train_annotation_path="datasets/bridge/annotation/train", + val_annotation_path="datasets/bridge/annotation/val", + test_annotation_path="datasets/bridge/annotation/test", + video_path="datasets/bridge/", + sequence_interval=1, + num_frames=2, + cam_ids=[0], + accumulate_action=False, + video_size=[256, 320], + val_start_frame_interval=1, + mode="train", + load_t5_embeddings=True, + ) + + indices = [0, 13, 200, -1] + for idx in indices: + print( + ( + f"{idx=} " + f"{dataset[idx]['video'].sum()=}\n" + f"{dataset[idx]['video'].shape=}\n" + f"{dataset[idx]['video_name']=}\n" + f"{dataset[idx]['action'].sum()=}\n" + f"{dataset[idx]['json_action'].sum()=}\n" + "---" + ) + ) + + from IPython import embed + + embed() diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_multiview.py b/cosmos_predict1/diffusion/training/datasets/dataset_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..6f86e8a8e5c0518afa193b13ff8a9c01b37ed5f3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_multiview.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/diffusion/training/datasets/dataset_multiview.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import os +import pickle +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo + + +class Dataset(Dataset): + def __init__( + self, + dataset_dir, + sequence_interval, + num_frames, + view_keys, + video_size, + start_frame_interval=1, + ): + """Dataset class for loading image-text-to-video generation data. 
+ + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - video_name: Dict with episode/frame metadata + """ + + super().__init__() + self.dataset_dir = dataset_dir + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_frames + self.view_keys = view_keys + + video_dir = os.path.join(self.dataset_dir, "videos") + self.video_paths = [ + os.path.join(video_dir, view_keys[0], f) for f in os.listdir(os.path.join(video_dir, view_keys[0])) + ] + print(f"{len(self.video_paths)} videos in total") + + self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + cache_dir = os.path.join(self.dataset_dir, "cache") + self.prefix_t5_embeddings = {} + for view_key in view_keys: + with open(os.path.join(cache_dir, f"prefix_t5_embeddings_{view_key}.pickle"), "rb") as f: + self.prefix_t5_embeddings[view_key] = pickle.load(f)[0] + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["t5_embedding_path"] = os.path.join( + self.t5_dir, + os.path.basename(os.path.dirname(video_path)), + os.path.basename(video_path).replace(".mp4", ".pickle"), + ) + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + try: + fps = vr.get_avg_fps() + except Exception: # failed to read FPS + fps = 24 + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames, fps + + def __getitem__(self, index): 
+ try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + t5_embedding_path = sample["t5_embedding_path"] + + data = dict() + + videos = [] + t5_embeddings = [] + for view_key in self.view_keys: + video, fps = self._get_frames( + os.path.join(os.path.dirname(os.path.dirname(video_path)), view_key, os.path.basename(video_path)), + frame_ids, + ) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + videos.append(video) + + with open( + os.path.join( + os.path.dirname(os.path.dirname(t5_embedding_path)), + view_key, + os.path.basename(t5_embedding_path), + ), + "rb", + ) as f: + t5_embedding = pickle.load(f)[0] + t5_embedding = np.concatenate([self.prefix_t5_embeddings[view_key], t5_embedding], axis=0) + t5_embedding = torch.from_numpy(t5_embedding) + if t5_embedding.shape[0] < 512: + t5_embedding = torch.cat([t5_embedding, torch.zeros(512 - t5_embedding.shape[0], 1024)], dim=0) + t5_embeddings.append(t5_embedding) + video = torch.cat(videos, dim=1) + t5_embedding = torch.cat(t5_embeddings, dim=0) + + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "t5_embedding_path": t5_embedding_path, + "start_frame_id": str(frame_ids[0]), + } + data["t5_text_embeddings"] = t5_embedding + data["t5_text_mask"] = torch.ones(512 * len(self.view_keys), dtype=torch.int64) + data["fps"] = fps + data["image_size"] = torch.tensor([704, 1280, 704, 1280]) + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset( + dataset_dir="datasets/waymo/", + sequence_interval=1, + num_frames=57, + view_keys=[ + "pinhole_front_left", + "pinhole_front", + "pinhole_front_right", + "pinhole_side_left", + "pinhole_side_right", + ], + video_size=[240, 360], + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data['t5_text_embeddings'].shape=}\n" + "---" + ) + ) diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_utils.py b/cosmos_predict1/diffusion/training/datasets/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..963e4c9de3d2e7958dbaa0526b284650671965b3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_utils.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
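For reference, the multi-view `__getitem__` above packs all camera views into single tensors: per-view clips are concatenated along the frame axis and per-view T5 embeddings along the token axis, with the text mask sized accordingly. A toy sketch of the resulting shapes (illustrative sizes only, no real data):

import torch

num_views, C, T, H, W = 3, 3, 8, 240, 360
videos = [torch.zeros(C, T, H, W, dtype=torch.uint8) for _ in range(num_views)]  # per-view clips, already (C, T, H, W)
embeddings = [torch.zeros(512, 1024) for _ in range(num_views)]                  # per-view T5 embeddings padded to 512 tokens

video = torch.cat(videos, dim=1)                   # views stacked along frames: (C, num_views * T, H, W)
t5_text_embeddings = torch.cat(embeddings, dim=0)  # tokens stacked per view:    (num_views * 512, 1024)
t5_text_mask = torch.ones(512 * num_views, dtype=torch.int64)
print(video.shape, t5_text_embeddings.shape, t5_text_mask.shape)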
+ +""" +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_util.py +""" + +import base64 +import math +import os +from io import BytesIO + +import numpy as np +import torch +import torch.distributed as dist +import torchvision.transforms.functional as F +from PIL import Image + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def b64_2_img(data: str): + image_b64 = base64.b64decode(data) + img = Image.open(BytesIO(image_b64)).convert("RGB") + return img + + +def get_continuous_action(d_acts, c_act_max, c_act_min, n_bins): + c_act_max = c_act_max.to(d_acts.device) + c_act_min = c_act_min.to(d_acts.device) + c_acts = d_acts / (n_bins - 1) * (c_act_max - c_act_min) + c_act_min + return c_acts + + +def alpha2rotm(a): + """Alpha euler angle to rotation matrix.""" + rotm = np.array([[1, 0, 0], [0, np.cos(a), -np.sin(a)], [0, np.sin(a), np.cos(a)]]) + return rotm + + +def beta2rotm(b): + """Beta euler angle to rotation matrix.""" + rotm = np.array([[np.cos(b), 0, np.sin(b)], [0, 1, 0], [-np.sin(b), 0, np.cos(b)]]) + return rotm + + +def gamma2rotm(c): + """Gamma euler angle to rotation matrix.""" + rotm = np.array([[np.cos(c), -np.sin(c), 0], [np.sin(c), np.cos(c), 0], [0, 0, 1]]) + return rotm + + +def euler2rotm(euler_angles): + """Euler angle (ZYX) to rotation matrix.""" + alpha = euler_angles[0] + beta = euler_angles[1] + gamma = euler_angles[2] + + rotm_a = alpha2rotm(alpha) + rotm_b = beta2rotm(beta) + rotm_c = gamma2rotm(gamma) + + rotm = rotm_c @ rotm_b @ rotm_a + + return rotm + + +def isRotm(R): + # Checks if a matrix is a valid rotation matrix. 
+ # Forked from Andy Zeng + Rt = np.transpose(R) + shouldBeIdentity = np.dot(Rt, R) + I = np.identity(3, dtype=R.dtype) + n = np.linalg.norm(I - shouldBeIdentity) + return n < 1e-6 + + +def rotm2euler(R): + # Forked from: https://learnopencv.com/rotation-matrix-to-euler-angles/ + # R = Rz * Ry * Rx + assert isRotm(R) + sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0]) + singular = sy < 1e-6 + + if not singular: + x = math.atan2(R[2, 1], R[2, 2]) + y = math.atan2(-R[2, 0], sy) + z = math.atan2(R[1, 0], R[0, 0]) + else: + x = math.atan2(-R[1, 2], R[1, 1]) + y = math.atan2(-R[2, 0], sy) + z = 0 + + # (-pi , pi] + while x > np.pi: + x -= 2 * np.pi + while x <= -np.pi: + x += 2 * np.pi + while y > np.pi: + y -= 2 * np.pi + while y <= -np.pi: + y += 2 * np.pi + while z > np.pi: + z -= 2 * np.pi + while z <= -np.pi: + z += 2 * np.pi + return np.array([x, y, z]) + + +def get_converted_fp32_paths(deepspeed_ckpt_path): + deepspeed_ckpt_path = deepspeed_ckpt_path.rstrip("/") + ckpt_dir = os.path.dirname(deepspeed_ckpt_path) + ckpt_name = os.path.basename(deepspeed_ckpt_path) + fp32_ckpt_name = f"{ckpt_name}.fp32.pt" + converted_path = os.path.join(ckpt_dir, fp32_ckpt_name) + return converted_path + + +def quat2rotm(quat): + """Quaternion to rotation matrix. + + Args: + quat (4, numpy array): quaternion x, y, z, w + Returns: + rotm (3x3 numpy array): rotation matrix + """ + w = quat[3] + x = quat[0] + y = quat[1] + z = quat[2] + + s = w * w + x * x + y * y + z * z + + rotm = np.array( + [ + [1 - 2 * (y * y + z * z) / s, 2 * (x * y - z * w) / s, 2 * (x * z + y * w) / s], + [2 * (x * y + z * w) / s, 1 - 2 * (x * x + z * z) / s, 2 * (y * z - x * w) / s], + [2 * (x * z - y * w) / s, 2 * (y * z + x * w) / s, 1 - 2 * (x * x + y * y) / s], + ] + ) + + return rotm + + +class Resize_Preprocess: + def __init__(self, size): + """ + Initialize the preprocessing class with the target size. + Args: + size (tuple): The target height and width as a tuple (height, width). + """ + self.size = size + + def __call__(self, video_frames): + """ + Apply the transformation to each frame in the video. + Args: + video_frames (torch.Tensor): A tensor representing a batch of video frames. + Returns: + torch.Tensor: The transformed video frames. 
+ """ + # Resize each frame in the video + resized_frames = torch.stack([F.resize(frame, self.size, antialias=True) for frame in video_frames]) + return resized_frames + + +class Preprocess: + def __init__(self, size): + self.size = size + + def __call__(self, clip): + clip = Preprocess.resize_scale(clip, self.size[0], self.size[1], interpolation_mode="bilinear") + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + @staticmethod + def resize_scale(clip, target_height, target_width, interpolation_mode): + target_ratio = target_height / target_width + H = clip.size(-2) + W = clip.size(-1) + clip_ratio = H / W + if clip_ratio > target_ratio: + scale_ = target_width / W + else: + scale_ = target_height / H + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True diff --git a/cosmos_predict1/diffusion/training/datasets/dataset_video.py b/cosmos_predict1/diffusion/training/datasets/dataset_video.py new file mode 100644 index 0000000000000000000000000000000000000000..d728129d6c9bb6d2e73dbba246b1ce5558cf6509 --- /dev/null +++ b/cosmos_predict1/diffusion/training/datasets/dataset_video.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. 
python cosmos_predict1/diffusion/training/datasets/dataset_gear.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import os +import pickle +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo + + +class Dataset(Dataset): + def __init__( + self, + dataset_dir, + sequence_interval, + num_frames, + video_size, + start_frame_interval=1, + ): + """Dataset class for loading image-text-to-video generation data. + + Args: + dataset_dir (str): Base path to the dataset directory + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - video_name: Dict with episode/frame metadata + """ + + super().__init__() + self.dataset_dir = dataset_dir + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_frames + + video_dir = os.path.join(self.dataset_dir, "videos") + self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] + # print(f"{len(self.video_paths)} trajectories in total") + print(f"{len(self.video_paths)} videos in total") + + # self.t5_dir = os.path.join(self.dataset_dir, "labels") + self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))]) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.dataset_dir}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["t5_embedding_path"] = os.path.join( + # self.t5_dir, os.path.basename(video_path).replace(".mp4", ".npy") + self.t5_dir, + os.path.basename(video_path).replace(".mp4", ".pickle"), + ) + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert 
(np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + try: + fps = vr.get_avg_fps() + except Exception: # failed to read FPS + fps = 24 + return frame_data, fps + + def _get_frames(self, video_path, frame_ids): + frames, fps = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, w) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames, fps + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + video, fps = self._get_frames(video_path, frame_ids) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "t5_embedding_path": sample["t5_embedding_path"], + "start_frame_id": str(frame_ids[0]), + } + + # Just add these to fit the interface + # t5_embedding = np.load(sample["t5_embedding_path"])[0] + with open(sample["t5_embedding_path"], "rb") as f: + t5_embedding = pickle.load(f)[0] # [n_tokens, 1024] + n_tokens = t5_embedding.shape[0] + if n_tokens < 512: + t5_embedding = np.concatenate( + [t5_embedding, np.zeros((512 - n_tokens, 1024), dtype=np.float32)], axis=0 + ) + t5_text_mask = torch.zeros(512, dtype=torch.int64) + t5_text_mask[:n_tokens] = 1 + + data["t5_text_embeddings"] = torch.from_numpy(t5_embedding) + data["t5_text_mask"] = t5_text_mask + data["fps"] = fps + data["image_size"] = torch.tensor([704, 1280, 704, 1280]) + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280) + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset( + dataset_dir="datasets/cosmos_nemo_assets", + sequence_interval=1, + num_frames=57, + video_size=[240, 360], + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print( + ( + f"{idx=} " + f"{data['video'].sum()=}\n" + f"{data['video'].shape=}\n" + f"{data['video_name']=}\n" + f"{data['t5_text_embeddings'].shape=}\n" + "---" + ) + ) diff --git a/cosmos_predict1/diffusion/training/functional/loss.py b/cosmos_predict1/diffusion/training/functional/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..27d138371e86ae2ede2044e2f175306bbc63f59a --- /dev/null +++ b/cosmos_predict1/diffusion/training/functional/loss.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import torch + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul + + +def create_per_sample_loss_mask( + loss_masking_cfg: dict, + data_batch: dict, + x_shape: Tuple[int], + dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", +): + """ + Creates a per-sample loss mask based on the given configuration and input data batch. + + This function generates a dictionary of loss masks for each specified key in the loss masking configuration. + For keys present in both the configuration and the data batch, the corresponding data batch value is used. + For keys present only in the configuration, a tensor of zeros with the specified shape is created. + Additionally, it computes loss mask weights for each key based on the configuration values and adjusts them + based on the presence of certain keys in the data batch, such as "skip_face" and "object_loss_map". + + Note: + - The original `loss_masking_cfg` and `data_batch` are not modified by this function. + - For image data, it is assumed that the channel is always the first dimension. + - `skip_face` is for face regions that should be skipped during training, the key is provided so that we can generate + diverse human and avoid collapse to a single face given certain prompts. The issue happens for getty projects, + where face distribution in the dataset is high unbalanced that single man face can be shown in more than 100+ images. + + Parameters: + loss_masking_cfg (dict): Configuration for loss masking, specifying which keys to include and their weights. + data_batch (dict): The batch of data containing actual data points and potential mask indicators like "skip_face". + x_shape (tuple): The shape of the input data, used to initialize zero masks for keys not in the data batch. + dtype (torch.dtype): The data type for the tensors in the loss masks. + device (str, optional): The device on which to create the tensors. Defaults to 'cuda'. + + Returns: + dict: A dictionary containing combined loss masks adjusted according to the `loss_masking_cfg` and `data_batch`. + + Raises: + AssertionError: If "skip_face" is not present in `data_batch`. + + Note: `create_combined_loss_mask` is assumed to be a separate function that combines individual loss masks into a + single mask or set of masks based on the given parameters. Its behavior should be documented separately. + """ + loss_mask_data: dict = {} + for key in loss_masking_cfg: + if key not in data_batch: + loss_mask_data[key] = torch.zeros((x_shape[0], 1, x_shape[2], x_shape[3]), device=device) + else: + loss_mask_data[key] = data_batch[key] + + if "skip_face" not in data_batch: + # When skip_face is not there in data_dict, use 0 as default. This will not skip any sample. 
+ data_batch["skip_face"] = torch.zeros((x_shape[0],), dtype=dtype, device=device) + + loss_mask_weight: dict = {} + for k, v in loss_masking_cfg.items(): + loss_mask_weight[k] = torch.tensor(v, device=device).expand(data_batch["skip_face"].size()) + + if "human_face_mask" in loss_mask_weight: + loss_mask_weight["human_face_mask"] = (1 - data_batch["skip_face"]) * loss_mask_weight["human_face_mask"] + + if "object_loss_map" in data_batch: + loss_mask_weight["object_loss_map"] = torch.ones(data_batch["object_loss_map"].shape[0], device=device) + + return create_combined_loss_mask(loss_mask_data, x_shape, dtype, device, loss_mask_weight) + + +def create_combined_loss_mask(data, x_shape, dtype, device="cuda", loss_masking=None): + """ + Creates a combined loss mask from multiple input masks. + + This function combines several loss masks into a single mask. In regions where masks overlap, + the highest value is assigned. Non-overlapping regions are assigned a default value of 1. + Regions with a mask value of zero are explicitly zeroed out, which is essential for padded loss calculations. + + Example: + Given the following masks and weights: + mask1: [0, 1, 1, 1, 0, 0], weight: 2 + mask2: [1, 0, 1, 0, 0, 0], weight: 4 + mask3: [0, 1, 0, 0, 0, 0], weight: 0 + The resulting combined loss mask would be: + [4, 0, 4, 2, 1, 1] + + Parameters: + data (dict): Contains the loss masks and their weights. + x_shape (tuple): The shape of the output mask. + dtype: The data type for the output mask. + device: The device on which the output mask will be allocated. + loss_masking: The loss masking weight configuration. + + Returns: + torch.Tensor: The combined loss mask. + """ + + loss_mask = torch.ones(x_shape, dtype=dtype, device=device) + zero_mask = torch.ones(x_shape, dtype=dtype, device=device) + + if loss_masking: + for key in loss_masking: + # Repeat mask along channel's dimension. ndim=4 for images. + repeat_dims = (1, x_shape[1]) + tuple([1] * (data[key].ndim - 2)) + mask_key = torch.tile(data[key], dims=repeat_dims) + weight_key = loss_masking[key] + + # handle zero weight case + is_zero_weight = (weight_key == 0).float()[:, None, None, None] + zero_mask = zero_mask * ( + (1 - is_zero_weight) * torch.ones(x_shape, dtype=dtype, device=device) + + is_zero_weight * (1 - mask_key.bool().float()) + ) + + # calculate weights + no_mask_region = (mask_key.bool() == 0).float() + loss_mask = batch_mul(mask_key, weight_key) + batch_mul(no_mask_region, loss_mask) + + loss_mask_final = loss_mask * zero_mask + return loss_mask_final diff --git a/cosmos_predict1/diffusion/training/functional/lr_scheduler.py b/cosmos_predict1/diffusion/training/functional/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..579d6debaceeefb13d2304f7389090ca9b496a2d --- /dev/null +++ b/cosmos_predict1/diffusion/training/functional/lr_scheduler.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np + +from cosmos_predict1.utils import distributed, log + + +class TeroPolyScheduler: + def __init__( + self, + total_Mimg: int, + batch_size: int, + ref_Mimg: Optional[int] = None, + ref_batches: float = 70e3 / 1024, + max_lr_ratio: Optional[float] = 1.0, + min_lr_ratio: Optional[float] = None, + rampup_Mimg: float = 0, + rampdown_Mimg: int = 0, + verbosity_interval: int = 0, + formula: str = "poly", + poly_exp: float = 0.5, + ): + self.total_Mimg = total_Mimg + self.batch_size = batch_size * distributed.get_world_size() + self.ref_Mimg = ref_Mimg or ref_batches * batch_size / 1e6 + self.ref_batches = ref_batches + self.max_lr_ratio = max_lr_ratio + self.min_lr_ratio = min_lr_ratio + self.rampup_Mimg = rampup_Mimg + self.rampdown_Mimg = rampdown_Mimg + self.verbosity_interval = verbosity_interval + self.formula = formula + self.poly_exp = poly_exp + + self._model = None + + @property + def model(self): + return self._model + + @model.setter + def model(self, model): + self._model = model + + def schedule(self, n, **kwargs): + cur_Mimg = getattr(self.model, "sample_counter", 0) / 1e6 + + if self.formula == "constant": + lr = 1.0 + elif self.formula == "poly": + lr = max(cur_Mimg / self.ref_Mimg, 1e-8) ** -self.poly_exp + else: + raise ValueError(f'Invalid learning rate formula "{self.formula}"') + + if self.max_lr_ratio is not None: + lr = min(lr, self.max_lr_ratio) + if self.min_lr_ratio is not None: + lr = max(lr, self.min_lr_ratio) + + if self.rampup_Mimg > 0 and cur_Mimg < self.rampup_Mimg: + lr *= cur_Mimg / self.rampup_Mimg + if self.rampdown_Mimg > 0 and cur_Mimg > self.total_Mimg - self.rampdown_Mimg: + lr *= (self.total_Mimg - cur_Mimg) / self.rampdown_Mimg + + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler: + """ + A learning rate scheduler that combines warm-up with a cosine decay schedule for multiple cycles. + It supports different configurations for each cycle, including the number of warm-up steps, minimum + and maximum scaling factors for the learning rate. + + The scheduler is intended to be used with a base learning rate of 1.0, where the actual learning + rate at any step is the base learning rate multiplied by the scaling factor computed by the scheduler. + + Parameters: + warm_up_steps (list[int]): List of integers where each element represents the number of warm-up + steps for the corresponding cycle. + f_min (list[float]): List of the minimum scaling factors for each cycle after warm-up. + f_max (list[float]): List of the maximum scaling factors at the start and end of each cosine cycle. + f_start (list[float]): List of starting scaling factors for each warm-up phase. + cycle_lengths (list[int]): List of the total lengths of each cycle, including warm-up steps. + verbosity_interval (int, optional): Interval of training steps at which to print current step and + scaling factor information. Set to 0 by default to disable verbosity. 
+ + Examples: + >>> scheduler = LambdaWarmUpCosineScheduler2( + warm_up_steps=[10, 10], + f_min=[0.1, 0.1], + f_max=[1.0, 1.0], + f_start=[0.01, 0.01], + cycle_lengths=[50, 50], + verbosity_interval=10) + >>> for step in range(100): + >>> lr_multiplier = scheduler(step) + >>> print(f"Step {step}: LR Multiplier = {lr_multiplier}") + """ + + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler): + """ + Linear instead of cosine decay for the main part of the cycle. + """ + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) + self.last_f = f + return f diff --git a/cosmos_predict1/diffusion/training/models/extend_model.py b/cosmos_predict1/diffusion/training/models/extend_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb8ed2b37908fff47d1ec6dfddcce8a44756177 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/extend_model.py @@ -0,0 +1,576 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
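The schedulers above return a multiplier on a base learning rate of 1.0, so they plug directly into torch.optim.lr_scheduler.LambdaLR. A minimal sketch with LambdaLinearScheduler, assuming the cosmos_predict1 package is importable; the hyperparameters, toy model, and optimizer settings are illustrative placeholders:

import torch

from cosmos_predict1.diffusion.training.functional.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_multiplier = LambdaLinearScheduler(
    warm_up_steps=[1000], f_min=[0.1], f_max=[1.0], f_start=[1e-3], cycle_lengths=[100_000]
)
# LambdaLR multiplies the base lr by the value the scheduler returns for the current step.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_multiplier)

for step in range(10):
    optimizer.step()
    lr_scheduler.step()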
+ +from dataclasses import dataclass +from statistics import NormalDist +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.config.base.conditioner import VideoCondBoolConfig +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.training.conditioner import DataType, VideoExtendCondition +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.model import DiffusionModel as BaseModel +from cosmos_predict1.diffusion.training.models.model import _broadcast, broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log, misc + + +@dataclass +class VideoDenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty + net_in: Optional[torch.Tensor] = None # input to the network + net_x0_pred: Optional[torch.Tensor] = None # prediction of x0 from the network + xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in + x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent + + +def normalize_condition_latent(condition_latent): + """Normalize the condition latent tensor to have zero mean and unit variance + Args: + condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W + """ + condition_latent_2D = rearrange(condition_latent, "b c t h w -> b c t (h w)") + mean = condition_latent_2D.mean(dim=-1) + std = condition_latent_2D.std(dim=-1) + # bct -> bct11 + mean = mean.unsqueeze(-1).unsqueeze(-1) + std = std.unsqueeze(-1).unsqueeze(-1) + condition_latent = (condition_latent - mean) / std + return condition_latent + + +class ExtendDiffusionModel(BaseModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + + def get_data_and_condition( + self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None + ) -> Tuple[Tensor, VideoExtendCondition]: + raw_state, latent_state, condition = super().get_data_and_condition(data_batch) + if condition.data_type == DataType.VIDEO: + if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + latent_state = self.sample_tokens_start_from_p_or_i(latent_state) + condition = self.add_condition_video_indicator_and_video_input_mask( + latent_state, condition, num_condition_t=num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + log.debug(f"condition.data_type {condition.data_type}") + return raw_state, latent_state, condition + + def draw_augment_sigma_and_epsilon( + self, size: int, condition: VideoExtendCondition, p_mean: float, p_std: float, multiplier: float + ) -> Tensor: + is_video_batch = condition.data_type == DataType.VIDEO + del condition + batch_size = size[0] + epsilon = torch.randn(size, **self.tensor_kwargs) + + gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = 
torch.tensor(samples_interval_gaussian, device="cuda") + sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) + + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + def augment_conditional_latent_frames( + self, + condition: VideoExtendCondition, + cfg_video_cond_bool: VideoCondBoolConfig, + gt_latent: Tensor, + condition_video_augment_sigma_in_inference: float = 0.001, + sigma: Tensor = None, + seed_inference: int = 1, + ) -> Union[VideoExtendCondition, Tensor]: + """This function is used to augment the condition input with noise + Args: + condition (VideoExtendCondition): condition object + condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. + condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. + cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config + gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + sigma (Tensor): noise level for the generation region + Returns: + VideoExtendCondition: updated condition object + condition_video_augment_sigma: sigma for the condition region, feed to the network + augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W + + """ + + if cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma": + # Training only, sample sigma for the condition region + augment_sigma, _ = self.draw_augment_sigma_and_epsilon( + gt_latent.shape, + condition, + cfg_video_cond_bool.augment_sigma_sample_p_mean, + cfg_video_cond_bool.augment_sigma_sample_p_std, + cfg_video_cond_bool.augment_sigma_sample_multiplier, + ) + noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) + + elif cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma_fixed": + # Inference only, use fixed sigma for the condition region + log.debug( + f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}, sigma={sigma.flatten()[0]}" + ) + assert ( + condition_video_augment_sigma_in_inference is not None + ), "condition_video_augment_sigma_in_inference should be provided" + augment_sigma = condition_video_augment_sigma_in_inference + + if augment_sigma >= sigma.flatten()[0]: + # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. + # This is achieved by setting all region as `generation`, i.e. 
value=0 + log.debug("augment_sigma larger than sigma or other frame, remove condition") + condition.condition_video_indicator = condition.condition_video_indicator * 0 + + augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) + + # Inference, use fixed seed + noise = misc.arch_invariant_rand( + gt_latent.shape, + torch.float32, + self.tensor_kwargs["device"], + seed_inference, + ) + else: + raise ValueError(f"does not support {cfg_video_cond_bool.apply_corruption_to_condition_region}") + + # Now apply the augment_sigma to the gt_latent + + augment_latent = gt_latent + noise * augment_sigma.view(-1, 1, 1, 1, 1) + _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) + + if cfg_video_cond_bool.condition_on_augment_sigma: # model takes augment_sigma as input + if condition.condition_video_indicator.sum() > 0: # has condition frames + condition.condition_video_augment_sigma = c_noise_augment + else: # no condition frames + condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) + + # Multiply the whole latent with c_in_augment + augment_latent_cin = batch_mul(augment_latent, c_in_augment) + + # Since the whole latent will multiply with c_in later, we divide the value to cancel the effect + _, _, c_in, _ = self.scaling(sigma=sigma) + augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) + + return condition, augment_latent_cin + + def drop_out_condition_region( + self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig + ) -> Tensor: + """Used for CFG on input frames; we drop out the conditional region. + There are two options: + 1. when we drop out, we set the region to be zero + 2. when we drop out, we set the region to be noise_x + """ + # Unconditional case, use for cfg + if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": + # Set the condition location input to be zero + augment_latent_drop = torch.zeros_like(augment_latent) + elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": + # Set the condition location input to be noise_x, i.e., same as base model training + augment_latent_drop = noise_x + else: + raise NotImplementedError( + f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" + ) + return augment_latent_drop + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + seed_inference: int = 1, + ) -> VideoDenoisePrediction: + """ + Denoise the noisy input tensor. + + Args: + noise_x (Tensor): Noisy input tensor. + sigma (Tensor): Noise level. + condition (VideoExtendCondition): Condition for denoising. + condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + + Returns: + Tensor: Denoised output tensor.
+ """ + if condition.data_type == DataType.IMAGE: + pred = super().denoise(noise_x, sigma, condition) + log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) + return VideoDenoisePrediction( + x0=pred.x0, + eps=pred.eps, + logvar=pred.logvar, + xt=noise_x, + ) + else: + assert ( + condition.gt_latent is not None + ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = normalize_condition_latent(condition_latent) + + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + condition, augment_latent = self.augment_conditional_latent_frames( + condition, + cfg_video_cond_bool, + condition_latent, + condition_video_augment_sigma_in_inference, + sigma, + seed_inference=seed_inference, + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + if not condition.video_cond_bool: + # Unconditional case, drop out the condition region + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + # Call the abse model + denoise_pred = super().denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + # We also denoise the conditional region + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. 
+ If this feature is stabilized, we could consider moving this function to the base model. + + Args: + condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. + num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half + + add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + return_noise (bool): return the initial noise or not, used for ODE pairs generation + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + seed_inference=seed, # Use for noise of augment sigma + ) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + seed_inference: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + Different from the base model, this function supports a condition latent as input; it will add the condition information into the condition and uncondition objects. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states.
+ + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. + - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + seed_inference=seed_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. 
+ Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should 
be turned off." + + return condition + + def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: + """Add pose condition to the condition object. For camera control model + Args: + data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + assert ( + "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() + ), f"plucker_embeddings should be in data_batch. only find {data_batch.keys()}" + plucker_embeddings = ( + data_batch["plucker_embeddings"] + if "plucker_embeddings_downsample" not in data_batch.keys() + else data_batch["plucker_embeddings_downsample"] + ) + condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def sample_tokens_start_from_p_or_i(self, latent_state: torch.Tensor) -> torch.Tensor: + """Sample the PPP... from the IPPP... sequence, only for video sequence + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + Returns: + torch.Tensor: sampled PPP tensor in shape B,C,T,H,W + """ + B, C, T, H, W = latent_state.shape + latent_dtype = latent_state.dtype + T_target = self.state_shape[1] + latent_state_sample = torch.zeros((B, C, T_target, H, W), dtype=latent_dtype, device=latent_state.device) + t_start = torch.randint(0, T - T_target + 1, (1,)) + # broadcast to other device + latent_state_sample = latent_state[:, :, t_start : t_start + T_target].contiguous() + if parallel_state.is_initialized(): + latent_state_sample = _broadcast(latent_state_sample, to_tp=True, to_cp=True) + + return latent_state_sample + + +@diffusion_fsdp_class_decorator +class FSDPExtendDiffusionModel(ExtendDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/extend_model_multiview.py b/cosmos_predict1/diffusion/training/models/extend_model_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..3c66f32f04015229b2723b06030171b8f5ead3a3 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/extend_model_multiview.py @@ -0,0 +1,446 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
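Before the multiview variant's code below, here is a tiny self-contained toy (not part of the patch; all shapes and values are made up) of the masking convention that ExtendDiffusionModel.denoise above relies on: latent frames whose condition_video_indicator is 1 are taken from the (noise-augmented) ground-truth latent, while the remaining frames come from the noisy input being denoised.

import torch

# Toy shapes; indicator broadcasts over channels and spatial dims.
B, C, T, H, W = 1, 4, 8, 2, 2
indicator = torch.zeros(B, 1, T, 1, 1)        # 1 marks condition frames, 0 generation frames
indicator[:, :, :2] = 1.0                     # e.g. condition on the first two latent frames
augment_latent = torch.randn(B, C, T, H, W)   # stand-in for the augmented condition latent
noise_x = torch.randn(B, C, T, H, W)          # stand-in for the noisy generation input

net_input = indicator * augment_latent + (1 - indicator) * noise_x
assert torch.equal(net_input[:, :, :2], augment_latent[:, :, :2])  # condition region kept
assert torch.equal(net_input[:, :, 2:], noise_x[:, :, 2:])         # generation region kept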
+ +from typing import Callable, Dict, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.training.conditioner import DataType, VideoExtendCondition +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.extend_model import ( + ExtendDiffusionModel, + VideoDenoisePrediction, + normalize_condition_latent, +) +from cosmos_predict1.diffusion.training.models.model import DiffusionModel, broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log + + +class MultiviewExtendDiffusionModel(ExtendDiffusionModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.vae.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.vae.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + + x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + output_batch, kendall_loss, pred_mse, edm_loss = super( + DiffusionModel, self + ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return output_batch, kendall_loss, pred_mse, edm_loss + + def denoise( + self, + noise_x: Tensor, + sigma: Tensor, + condition: VideoExtendCondition, + condition_video_augment_sigma_in_inference: float = 0.001, + ) -> VideoDenoisePrediction: + """ + Denoise the noisy input tensor. + + Args: + noise_x (Tensor): Noisy input tensor. + sigma (Tensor): Noise level. + condition (VideoExtendCondition): Condition for denoising. 
+ condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + + Returns: + Tensor: Denoised output tensor. + """ + if condition.data_type == DataType.IMAGE: + pred = super(DiffusionModel, self).denoise(noise_x, sigma, condition) + log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) + return VideoDenoisePrediction( + x0=pred.x0, + eps=pred.eps, + logvar=pred.logvar, + xt=noise_x, + ) + else: + assert ( + condition.gt_latent is not None + ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" + gt_latent = condition.gt_latent + cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool + + condition_latent = gt_latent + + if cfg_video_cond_bool.normalize_condition_latent: + condition_latent = normalize_condition_latent(condition_latent) + + # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed + condition, augment_latent = self.augment_conditional_latent_frames( + condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma + ) + condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] + if parallel_state.get_context_parallel_world_size() > 1: + cp_group = parallel_state.get_context_parallel_group() + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + augment_latent = rearrange(augment_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + gt_latent = rearrange(gt_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) + augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) + gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) + + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + augment_latent = rearrange(augment_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + gt_latent = rearrange(gt_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + if not condition.video_cond_bool: + # Unconditional case, drop out the condition region + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) + # Compose the model input with condition region (augment_latent) and generation region (noise_x) + new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x + # Call the abse model + + denoise_pred = super(DiffusionModel, self).denoise(new_noise_xt, sigma, condition) + + x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 + if cfg_video_cond_bool.compute_loss_for_condition_region: + # We also denoise the conditional region + x0_pred = denoise_pred.x0 + else: + x0_pred = x0_pred_replaced + + return VideoDenoisePrediction( + x0=x0_pred, + eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), + logvar=denoise_pred.logvar, + net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), + net_x0_pred=denoise_pred.x0, + xt=new_noise_xt, + x0_pred_replaced=x0_pred_replaced, + ) + + def add_condition_video_indicator_and_video_input_mask( + self, 
latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. + Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + + condition_video_indicator = rearrange( + condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views + ) + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + num_condition_t = torch.randint(0, num_condition_t_max + 1, (1,)).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + + condition_video_indicator = rearrange( + condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views + ) + + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = 
broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + def get_x0_fn_from_batch_with_condition_latent( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + condition_latent: torch.Tensor = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + guidance_other: Union[float, None] = None, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + Different from the base model, this function support condition latent as input, it will add the condition information into the condition and uncondition object. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. + - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference + - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. 
+ """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + condition.video_cond_bool = True + condition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, condition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + + uncondition.video_cond_bool = False if add_input_frames_guidance else True + uncondition = self.add_condition_video_indicator_and_video_input_mask( + condition_latent, uncondition, num_condition_t + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + uncondition = self.add_condition_pose(data_batch, uncondition) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + if guidance_other is not None: # and guidance_other != guidance: + import copy + + assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances." + condition_other = copy.deepcopy(uncondition) + condition_other.trajectory = condition.trajectory + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + cond_other_x0 = self.denoise( + noise_x, + sigma, + condition_other, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0) + + else: + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise( + noise_x, + sigma, + condition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + uncond_x0 = self.denoise( + noise_x, + sigma, + uncondition, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + ).x0_pred_replaced + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + condition_latent: Union[torch.Tensor, None] = None, + num_condition_t: Union[int, None] = None, + condition_video_augment_sigma_in_inference: float = None, + add_input_frames_guidance: bool = False, + guidance_other: Union[float, None] = None, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. 
+ If this feature is stablized, we could consider to move this function to the base model. + + Args: + condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. + num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half + + add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if is_image_batch: + log.debug("image batch, call base model generate_samples_from_batch") + return super().generate_samples_from_batch( + data_batch, + guidance=guidance, + seed=seed, + state_shape=state_shape, + n_sample=n_sample, + is_negative_prompt=is_negative_prompt, + num_steps=num_steps, + ) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + else: + log.debug(f"Default Video state shape is used. {self.state_shape}") + state_shape = self.state_shape + + assert condition_latent is not None, "condition_latent should be provided" + x0_fn = self.get_x0_fn_from_batch_with_condition_latent( + data_batch, + guidance, + is_negative_prompt=is_negative_prompt, + condition_latent=condition_latent, + num_condition_t=num_condition_t, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + add_input_frames_guidance=add_input_frames_guidance, + guidance_other=guidance_other, + ) + + generator = torch.Generator(device=self.tensor_kwargs["device"]) + generator.manual_seed(seed) + x_sigma_max = ( + torch.randn(n_sample, *state_shape, **self.tensor_kwargs, generator=generator) * self.sde.sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) + if self.net.is_context_parallel_enabled: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return samples + + +@diffusion_fsdp_class_decorator +class FSDPExtendDiffusionModel(MultiviewExtendDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/interpolator.py b/cosmos_predict1/diffusion/training/models/interpolator.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1f883d69dc5b5976549fff2f5410647fa25ae5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/interpolator.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.training.conditioner import DataType, VideoExtendCondition +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.diffusion.training.models.model import DiffusionModel as BaseModel +from cosmos_predict1.diffusion.training.models.model import broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log + + +class InterpolatorDiffusionModel(ExtendDiffusionModel): + def __init__(self, config): + super().__init__(config) + self.is_extend_model = True + self.num_valid_latents = config.latent_shape[1] - config.num_latents_to_drop + self.pixel_chunk_duration = config.vae.video_vae.pixel_chunk_duration + self.input_image_key = getattr(self.config, "input_image_key", None) + self.input_data_key = self.config.input_data_key + + def get_data_and_condition( + self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None + ) -> Tuple[Tensor, VideoExtendCondition]: + raw_state, latent_state, condition = BaseModel.get_data_and_condition(self, data_batch) + num_valid_frames = raw_state.shape[2] - self.pixel_chunk_duration + 1 + raw_state, latent_state = ( + raw_state[:, :, :num_valid_frames, ...], + latent_state[:, :, : self.num_valid_latents, ...], + ) # [B, C, T, H, W] + raw_state, latent_state = raw_state.contiguous(), latent_state.contiguous() + if condition.data_type == DataType.VIDEO: + if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + latent_state = self.sample_tokens_start_from_p_or_i(latent_state) + condition = self.add_condition_video_indicator_and_video_input_mask( + latent_state, condition, num_condition_t=1 + ) + if self.config.conditioner.video_cond_bool.add_pose_condition: + condition = self.add_condition_pose(data_batch, condition) + log.debug(f"condition.data_type {condition.data_type}") + return raw_state, latent_state, condition + + def add_condition_video_indicator_and_video_input_mask( + self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None + ) -> VideoExtendCondition: + """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. + condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. + condition_video_input_mask will be concat with the input for the network. 
+ Args: + latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W + condition (VideoExtendCondition): condition object + num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" + Returns: + VideoExtendCondition: updated condition object + """ + T = latent_state.shape[2] + latent_dtype = latent_state.dtype + condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( + latent_dtype + ) # 1 for condition region + if self.config.conditioner.video_cond_bool.condition_location == "first_n": + # Only in inference to decide the condition region + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + # Should be used for both training and inference. The first and last frame will be condition frames. + assert num_condition_t is not None, "num_condition_t should be provided" + assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" + log.info( + f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" + ) + condition_video_indicator[:, :, :num_condition_t] += 1.0 + condition_video_indicator[:, :, -num_condition_t:] += 1.0 + elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": + # Only in training + num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max + assert ( + num_condition_t_max <= T + ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" + assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min + num_condition_t = torch.randint( + self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, + num_condition_t_max + 1, + (1,), + ).item() + condition_video_indicator[:, :, :num_condition_t] += 1.0 + + elif self.config.conditioner.video_cond_bool.condition_location == "random": + # Only in training + condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate + flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate + condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) + else: + raise NotImplementedError( + f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" + ) + condition.gt_latent = latent_state + condition.condition_video_indicator = condition_video_indicator + + B, C, T, H, W = latent_state.shape + # Create additional input_mask channel, this will be concatenated to the input of the network + # See design doc section (Implementation detail A.1 and A.2) for visualization + ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) + assert condition.video_cond_bool is not None, "video_cond_bool should be set" + + # The input mask indicate whether the input is 
conditional region or not + if condition.video_cond_bool: # Condition one given video frames + condition.condition_video_input_mask = ( + condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding + ) + else: # Unconditional case, use for cfg + condition.condition_video_input_mask = zeros_padding + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + return condition + + +@diffusion_fsdp_class_decorator +class FSDPInterpolatorDiffusionModel(InterpolatorDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model.py b/cosmos_predict1/diffusion/training/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e1f7892451d410e5496540454fe8e08c37ff49 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model.py @@ -0,0 +1,662 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union + +import amp_C +import torch +from apex.multi_tensor_apply import multi_tensor_applier +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.distributed import broadcast_object_list, get_process_group_ranks +from torch.distributed.utils import _verify_param_shape_across_processes + +from cosmos_predict1.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS +from cosmos_predict1.diffusion.training.conditioner import BaseVideoCondition, DataType +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.model_image import CosmosCondition +from cosmos_predict1.diffusion.training.models.model_image import DiffusionModel as ImageModel +from cosmos_predict1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator +from cosmos_predict1.utils import distributed, log, misc + +l2_norm_impl = amp_C.multi_tensor_l2norm +multi_tensor_scale_impl = amp_C.multi_tensor_scale + +# key to check if the video data is normalized or image data is converted to video data +# to avoid apply normalization or augment image dimension multiple times +# It is due to we do not have normalization and augment image dimension in the dataloader and move it to the model +IS_PREPROCESSED_KEY = "is_preprocessed" + + +def robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: + """ + Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. 
+ + Args: + tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). + src (int): The source rank for the broadcast. Defaults to 0. + + Returns: + torch.Tensor: The broadcasted tensor on all ranks. + """ + # First, broadcast the shape of the tensor + if distributed.get_rank() == src: + shape = torch.tensor(tensor.shape).cuda() + else: + shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() + if is_check_shape: + _verify_param_shape_across_processes(pg, [shape]) + torch.distributed.broadcast(shape, src, group=pg) + + # Resize the tensor on non-src ranks if necessary + if distributed.get_rank() != src: + tensor = tensor.new_empty(shape.tolist()).type_as(tensor) + + # Now broadcast the tensor data + torch.distributed.broadcast(tensor, src, group=pg) + + return tensor + + +def _broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: + """ + Broadcast the item from the minimum rank in the specified group(s). + Since global rank = tp_rank + cp_rank * tp_size + ... + First broadcast in the tp_group and then in the cp_group will + ensure that the item is broadcasted across ranks in cp_group and tp_group. + + Parameters: + - item: The item to broadcast (can be a torch.Tensor, str, or None). + - to_tp: Whether to broadcast to the tensor model parallel group. + - to_cp: Whether to broadcast to the context parallel group. + """ + if not parallel_state.is_initialized(): + return item + tp_group = parallel_state.get_tensor_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + + to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 + to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 + + if to_tp: + min_tp_rank = min(get_process_group_ranks(tp_group)) + + if to_cp: + min_cp_rank = min(get_process_group_ranks(cp_group)) + + if isinstance(item, torch.Tensor): # assume the device is cuda + # log.info(f"{item.shape}", rank0_only=False) + if to_tp: + # torch.distributed.broadcast(item, min_tp_rank, group=tp_group) + item = robust_broadcast(item, min_tp_rank, tp_group) + if to_cp: + # torch.distributed.broadcast(item, min_cp_rank, group=cp_group) + item = robust_broadcast(item, min_cp_rank, cp_group) + elif item is not None: + broadcastable_list = [item] + if to_tp: + # log.info(f"{broadcastable_list}", rank0_only=False) + broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) + if to_cp: + broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) + + item = broadcastable_list[0] + return item + + +def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: + condition_kwargs = {} + for k, v in condition.to_dict().items(): + if isinstance(v, torch.Tensor): + assert not v.requires_grad, f"{k} requires gradient. 
the current impl does not support it" + condition_kwargs[k] = _broadcast(v, to_tp=to_tp, to_cp=to_cp) + condition = type(condition)(**condition_kwargs) + return condition + + +class DiffusionModel(ImageModel): + def __init__(self, config): + super().__init__(config) + # Initialize trained_data_record with defaultdict, key: image, video, iteration + self.trained_data_record = { + "image": 0, + "video": 0, + "iteration": 0, + } + if parallel_state.is_initialized(): + self.data_parallel_size = parallel_state.get_data_parallel_world_size() + else: + self.data_parallel_size = 1 + + if self.config.adjust_video_noise: + self.video_noise_multiplier = math.sqrt(self.state_shape[1]) + else: + self.video_noise_multiplier = 1.0 + + def setup_data_key(self) -> None: + self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model + self.input_image_key = self.config.input_image_key + + def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: + """We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch. + Another comes from a dataloader which we by default assumes as video_data for video model training. + """ + is_image = self.input_image_key in data_batch + is_video = self.input_data_key in data_batch + assert ( + is_image != is_video + ), "Only one of the input_image_key or input_data_key should be present in the data_batch." + return is_image + + def draw_training_sigma_and_epsilon(self, size: int, condition: BaseVideoCondition) -> Tensor: + sigma_B, epsilon = super().draw_training_sigma_and_epsilon(size, condition) + is_video_batch = condition.data_type == DataType.VIDEO + multiplier = self.video_noise_multiplier if is_video_batch else 1 + sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) + epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) + return sigma_B, epsilon + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + save generated videos + """ + raw_data, x0, condition = self.get_data_and_condition(data) + guidance = data["guidance"] + data = misc.to(data, **self.tensor_kwargs) + sample = self.generate_samples_from_batch( + data, + guidance=guidance, + # make sure no mismatch and also works for cp + state_shape=x0.shape[1:], + n_sample=x0.shape[0], + ) + sample = self.decode(sample) + gt = raw_data + caption = data["ai_caption"] + return {"gt": gt, "result": sample, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) + + def training_step(self, data_batch: Dict[str, Tensor], iteration: int) -> Tuple[Dict[str, Tensor] | Tensor]: + input_key = self.input_data_key # by default it is video key + if self.is_image_batch(data_batch): + input_key = self.input_image_key + batch_size = data_batch[input_key].shape[0] + self.trained_data_record["image" if self.is_image_batch(data_batch) else "video"] += ( + batch_size * self.data_parallel_size + ) + self.trained_data_record["iteration"] += 1 + return super().training_step(data_batch, iteration) + + def state_dict(self) -> Dict[str, Any]: + state_dict = super().state_dict() + state_dict["trained_data_record"] = self.trained_data_record + return state_dict + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + if "trained_data_record" in state_dict and hasattr(self, "trained_data_record"): + trained_data_record = 
state_dict.pop("trained_data_record") + if trained_data_record: + assert set(trained_data_record.keys()) == set(self.trained_data_record.keys()) + for k, v in trained_data_record.items(): + self.trained_data_record[k] = v + else: + log.warning("trained_data_record not found in the state_dict.") + return super().load_state_dict(state_dict, strict, assign) + + def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + """ + Normalizes video data in-place on a CUDA device to reduce data loading overhead. + + This function modifies the video data tensor within the provided data_batch dictionary + in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1]. + + Warning: + A warning is issued if the data has not been previously normalized. + + Args: + data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key. + This tensor is expected to be on a CUDA device and have dtype of torch.uint8. + + Side Effects: + Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place. + + Note: + This operation is performed directly on the CUDA device to avoid the overhead associated + with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device + and has the correct dtype (torch.uint8) to avoid unexpected behaviors. + """ + input_key = self.input_data_key if input_key is None else input_key + # only handle video batch + if input_key in data_batch: + # Check if the data has already been normalized and avoid re-normalizing + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." + assert torch.all( + (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) + ), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" + else: + assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." + data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[IS_PREPROCESSED_KEY] = True + + def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: + input_key = self.input_image_key if input_key is None else input_key + if input_key in data_batch: + # Check if the data has already been augmented and avoid re-augmenting + if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: + assert ( + data_batch[input_key].shape[2] == 1 + ), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" + return + else: + data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() + data_batch[IS_PREPROCESSED_KEY] = True + + def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor, BaseVideoCondition]: + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + input_key = self.input_data_key # by default it is video key + is_image_batch = self.is_image_batch(data_batch) + is_video_batch = not is_image_batch + + # Broadcast data and condition across TP and CP groups. + # sort keys to make sure the order is same, IMPORTANT! otherwise, nccl will hang! 
+ local_keys = sorted(list(data_batch.keys())) + # log.critical(f"all keys {local_keys}", rank0_only=False) + for key in local_keys: + data_batch[key] = _broadcast(data_batch[key], to_tp=True, to_cp=is_video_batch) + + if is_image_batch: + input_key = self.input_image_key + + # Latent state + raw_state = data_batch[input_key] + latent_state = self.encode(raw_state).contiguous() + + # Condition + condition = self.conditioner(data_batch) + if is_image_batch: + condition.data_type = DataType.IMAGE + else: + condition.data_type = DataType.VIDEO + + # VAE has randomness. CP/TP group should have the same encoded output. + + latent_state = _broadcast(latent_state, to_tp=True, to_cp=is_video_batch) + condition = broadcast_condition(condition, to_tp=True, to_cp=is_video_batch) + + return raw_state, latent_state, condition + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + super().on_train_start(memory_format) + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + if sequence_parallel: + self.net.enable_sequence_parallel() + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + output_batch, kendall_loss, pred_mse, edm_loss = super().compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return output_batch, kendall_loss, pred_mse, edm_loss + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. 
+ - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 + + return x0_fn + + def get_x_from_clean( + self, + in_clean_img: torch.Tensor, + sigma_max: float | None, + seed: int = 1, + ) -> Tensor: + """ + in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising + sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video + """ + if in_clean_img is None: + return None + generator = torch.Generator(device=self.tensor_kwargs["device"]) + generator.manual_seed(seed) + noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) + if sigma_max is None: + sigma_max = self.sde.sigma_max + x_sigma_max = in_clean_img + noise * sigma_max + return x_sigma_max + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + return_noise: bool = False, + ) -> Tensor | Tuple[Tensor, Tensor]: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. 
+ guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + return_noise (bool): return the initial noise or not, used for ODE pairs generation + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + + if self.net.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + + if return_noise: + if self.net.is_context_parallel_enabled: + x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + return samples, x_sigma_max / self.sde.sigma_max + + return samples + + def on_after_backward(self, iteration: int = 0): + finalize_model_grads([self]) + + def get_grad_norm( + self, + norm_type: Union[int, float] = 2, + filter_fn: Callable[[str, torch.nn.Parameter], bool] | None = None, + ) -> float: + """Calculate the norm of gradients, handling model parallel parameters. + + This function is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ + with added functionality to handle model parallel parameters. + + Args: + norm_type (float or int): Type of norm to use. Can be 2 for L2 norm. + 'inf' for infinity norm is not supported. + filter_fn (callable, optional): Function to filter parameters for norm calculation. + Takes parameter name and parameter as input, returns True if this parameter is sharded else False. + + Returns: + float: Total norm of the parameters (viewed as a single vector). + + Note: + - Uses NVIDIA's multi-tensor applier for efficient norm calculation. + - Handles both model parallel and non-model parallel parameters separately. + - Currently only supports L2 norm (norm_type = 2). 
+ """ + # Get model parallel group if parallel state is initialized + if parallel_state.is_initialized(): + model_parallel_group = parallel_state.get_model_parallel_group() + else: + model_parallel_group = None + + # Default filter function to identify tensor parallel parameters + if filter_fn is None: + + def is_tp(name, param): + return ( + any(key in name for key in ["to_q.0", "to_k.0", "to_v.0", "to_out.0", "layer1", "layer2"]) + and "_extra_state" not in name + ) + + filter_fn = is_tp + + # Separate gradients into model parallel and non-model parallel + without_mp_grads_for_norm = [] + with_mp_grads_for_norm = [] + for name, param in self.named_parameters(): + if param.grad is not None: + if filter_fn(name, param): + with_mp_grads_for_norm.append(param.grad.detach()) + else: + without_mp_grads_for_norm.append(param.grad.detach()) + + # Only L2 norm is currently supported + if norm_type != 2.0: + raise NotImplementedError(f"Norm type {norm_type} is not supported. Only L2 norm (2.0) is implemented.") + + # Calculate L2 norm using NVIDIA's multi-tensor applier + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + + # Calculate norm for non-model parallel gradients + without_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") + if without_mp_grads_for_norm: + without_mp_grad_norm, _ = multi_tensor_applier( + l2_norm_impl, + dummy_overflow_buf, + [without_mp_grads_for_norm], + False, # no per-parameter norm + ) + + # Calculate norm for model parallel gradients + with_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") + if with_mp_grads_for_norm: + with_mp_grad_norm, _ = multi_tensor_applier( + l2_norm_impl, + dummy_overflow_buf, + [with_mp_grads_for_norm], + False, # no per-parameter norm + ) + + # Square the norms as we'll be summing across model parallel GPUs + total_without_mp_norm = without_mp_grad_norm**2 + total_with_mp_norm = with_mp_grad_norm**2 + + # Sum across all model-parallel GPUs + torch.distributed.all_reduce(total_with_mp_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group) + + # Combine norms from model parallel and non-model parallel gradients + total_norm = (total_with_mp_norm.item() + total_without_mp_norm.item()) ** 0.5 + + return total_norm + + def clip_grad_norm_(self, max_norm: float): + """ + This function performs gradient clipping to prevent exploding gradients. + It calculates the total norm of the gradients, and if it exceeds the + specified max_norm, scales the gradients down proportionally. + + Args: + max_norm (float): The maximum allowed norm for the gradients. + + Returns: + torch.Tensor: The total norm of the gradients before clipping. + + Note: + This implementation uses NVIDIA's multi-tensor applier for efficiency. 
+ """ + # Collect gradients from all parameters that require gradients + grads = [] + for param in self.parameters(): + if param.grad is not None: + grads.append(param.grad.detach()) + + # Calculate the total norm of the gradients + total_norm = self.get_grad_norm() + + # Compute the clipping coefficient + clip_coeff = max_norm / (total_norm + 1.0e-6) + + # Apply gradient clipping if the total norm exceeds max_norm + if clip_coeff < 1.0: + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + # Apply the scaling to the gradients using multi_tensor_applier for efficiency + multi_tensor_applier(multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff) + + return torch.tensor([total_norm]) + + +def _allreduce_layernorm_grads(model: List[torch.nn.Module]): + """ + All-reduce the following layernorm grads: + - When tensor parallel is enabled, all-reduce grads of QK-layernorm + - When sequence parallel, all-reduce grads of AdaLN, t_embedder, additional_timestamp_embedder, + and affline_norm. + """ + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + + if parallel_state.get_tensor_model_parallel_world_size() > 1: + grads = [] + for model_chunk in model: + for name, param in model_chunk.named_parameters(): + if not param.requires_grad: + continue + + if "to_q.1" in name or "to_k.1" in name: # TP # Q-layernorm # K-layernorm + # grad = param.main_grad + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if sequence_parallel: # TP + SP + if ( + "t_embedder" in name + or "adaLN_modulation" in name + or "additional_timestamp_embedder" in name + or "affline_norm" in name + or "input_hint_block" in name + or "zero_blocks" in name + ): + # grad = param.main_grad + grad = param.grad + if grad is not None: + grads.append(grad.data) + + if grads: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + +def finalize_model_grads(model: List[torch.nn.Module]): + """ + All-reduce layernorm grads for tensor/sequence parallelism. + Reference implementation: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/finalize_model_grads.py#L99 + """ + + _allreduce_layernorm_grads(model) + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(DiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model_image.py b/cosmos_predict1/diffusion/training/models/model_image.py new file mode 100644 index 0000000000000000000000000000000000000000..11ff1503c086f828cdc4afdd26567fad96f57fe4 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model_image.py @@ -0,0 +1,933 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +from contextlib import contextmanager +from dataclasses import dataclass, fields +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, TypeVar + +import numpy as np +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy, StateDictType +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from torch.nn.modules.module import _IncompatibleKeys + +from cosmos_predict1.diffusion.functional.batch_ops import batch_mul +from cosmos_predict1.diffusion.module.blocks import FourierFeatures +from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE +from cosmos_predict1.diffusion.modules.denoiser_scaling import EDMScaling +from cosmos_predict1.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler +from cosmos_predict1.diffusion.training.functional.loss import create_per_sample_loss_mask +from cosmos_predict1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh +from cosmos_predict1.diffusion.training.utils.optim_instantiate import get_base_scheduler +from cosmos_predict1.diffusion.types import DenoisePrediction +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.ema import FastEmaModelUpdater +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate +from cosmos_predict1.utils.model import Model + + +@dataclass +class CosmosCondition: + crossattn_emb: torch.Tensor + crossattn_mask: torch.Tensor + padding_mask: Optional[torch.Tensor] = None + scalar_feature: Optional[torch.Tensor] = None + + def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: + return {f.name: getattr(self, f.name) for f in fields(self)} + + +class DiffusionModel(Model): + def __init__(self, config): + super().__init__() + + self.config = config + + # how many sample have been processed + self.sample_counter = 0 + self.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} + log.warning(f"DiffusionModel: precision {self.precision}") + # Timer passed to network to detect slow ranks. + # 1. set data keys and data information + self.sigma_data = config.sigma_data + self.state_shape = list(config.latent_shape) + self.setup_data_key() + + # 2. setup up diffusion processing and scaling~(pre-condition), sampler + self.sde = lazy_instantiate(config.sde) + self.sampler = Sampler() + self.scaling = EDMScaling(self.sigma_data) + + # 3. vae + with misc.timer("DiffusionModel: set_up_vae"): + self.vae: BaseVAE = lazy_instantiate(config.vae) + assert ( + self.vae.latent_ch == self.state_shape[0] + ), f"latent_ch {self.vae.latent_ch} != state_shape {self.state_shape[0]}" + + # 4. Set up loss options, including loss masking, loss reduce and loss scaling + self.loss_masking: Optional[Dict] = config.loss_masking + self.loss_reduce = getattr(config, "loss_reduce", "mean") + assert self.loss_reduce in ["mean", "sum"] + self.loss_scale = getattr(config, "loss_scale", 1.0) + log.critical(f"Using {self.loss_reduce} loss reduce with loss scale {self.loss_scale}") + log.critical(f"Enable loss masking: {config.loss_mask_enabled}") + + # 5. 
diffusion neural networks part + self.set_up_model() + + def setup_data_key(self) -> None: + self.input_data_key = self.config.input_data_key + + def build_model(self) -> torch.nn.ModuleDict: + config = self.config + net = lazy_instantiate(config.net) + conditioner = lazy_instantiate(config.conditioner) + logvar = torch.nn.Sequential( + FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) + ) + + return torch.nn.ModuleDict( + { + "net": net, + "conditioner": conditioner, + "logvar": logvar, + } + ) + + @misc.timer("DiffusionModel: set_up_model") + def set_up_model(self): + config = self.config + self.model = self.build_model() + if config.ema.enabled: + with misc.timer("DiffusionModel: instantiate ema"): + config.ema.model = self.model + self.model_ema = lazy_instantiate(config.ema) + config.ema.model = None + else: + self.model_ema = None + + @property + def net(self): + return self.model.net + + @property + def conditioner(self): + return self.model.conditioner + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """ + update the model_ema + """ + if self.config.ema.enabled: + self.model_ema.update_average(self.model, iteration) + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + if self.config.ema.enabled: + self.model_ema.to(dtype=torch.float32) + if hasattr(self.vae, "reset_dtype"): + self.vae.reset_dtype() + self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) + + if hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile: # compatible with old config + if torch.__version__ < "2.3": + log.warning( + "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" + "It's very likely there will be no significant speedup from torch.compile.\n" + "Please use at least 24.04 Pytorch container." + ) + # Increasing cache size. It's required because of the model size and dynamic input shapes resulting in + # multiple different triton kernels. For 28 TransformerBlocks, the cache limit of 256 should be enough for + # up to 9 different input shapes, as 28*9 < 256. If you have more Blocks or input shapes, and you observe + # graph breaks at each Block (detectable with torch._dynamo.explain) or warnings about + # exceeding cache limit, you may want to increase this size. + # Starting with 24.05 Pytorch container, the default value is 256 anyway. + # You can read more about it in the comments in Pytorch source code under path torch/_dynamo/cache_size.py. + torch._dynamo.config.accumulated_cache_size_limit = 256 + # dynamic=False means that a separate kernel is created for each shape. It incurs higher compilation costs + # at initial iterations, but can result in more specialized and efficient kernels. + # dynamic=True currently throws errors in pytorch 2.3. + self.model.net = torch.compile(self.model.net, dynamic=False, disable=not self.config.use_torch_compile) + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + """ + Compute loss givee epsilon and sigma + + This method is responsible for computing loss give epsilon and sigma. It involves: + 1. Adding noise to the input data using the SDE process. + 2. 
Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data, \ + considering any configured loss weighting. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + x0_from_data_batch: raw image/video + x0: image/video latent + condition: text condition + epsilon: noise + sigma: noise level + + Returns: + tuple: A tuple containing four elements: + - dict: additional data that used to debug / logging / callbacks + - Tensor 1: kendall loss, + - Tensor 2: MSE loss, + - Tensor 3: EDM loss + + Raises: + AssertionError: If the class is conditional, \ + but no number of classes is specified in the network configuration. + + Notes: + - The method handles different types of conditioning + - The method also supports Kendall's loss + """ + # Get the mean and stand deviation of the marginal probability distribution. + mean, std = self.sde.marginal_prob(x0, sigma) + # Generate noisy observations + xt = mean + batch_mul(std, epsilon) # corrupted data + # make prediction + model_pred = self.denoise(xt, sigma, condition) + # loss weights for different noise levels + weights_per_sigma = self.get_per_sigma_loss_weights(sigma=sigma) + # extra weight for each sample, for example, aesthetic weight, camera weight + weights_per_sample = self.get_per_sample_weight(data_batch, x0_from_data_batch.shape[0]) + # extra loss mask for each sample, for example, human faces, hands + loss_mask_per_sample = self.get_per_sample_loss_mask(data_batch, x0_from_data_batch.shape, x0.shape) + pred_mse = (x0 - model_pred.x0) ** 2 * loss_mask_per_sample + edm_loss = batch_mul(pred_mse, weights_per_sigma * weights_per_sample) + if self.config.loss_add_logvar: + kendall_loss = batch_mul(edm_loss, torch.exp(-model_pred.logvar).view(-1)).flatten( + start_dim=1 + ) + model_pred.logvar.view(-1, 1) + else: + kendall_loss = edm_loss.flatten(start_dim=1) + output_batch = { + "x0": x0, + "xt": xt, + "sigma": sigma, + "weights_per_sigma": weights_per_sigma, + "weights_per_sample": weights_per_sample, + "loss_mask_per_sample": loss_mask_per_sample, + "condition": condition, + "model_pred": model_pred, + "mse_loss": pred_mse.mean(), + "edm_loss": edm_loss.mean(), + } + return output_batch, kendall_loss, pred_mse, edm_loss + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Performs a single training step for the diffusion model. + + This method is responsible for executing one iteration of the model's training. It involves: + 1. Adding noise to the input data using the SDE process. + 2. Passing the noisy data through the network to generate predictions. + 3. Computing the loss based on the difference between the predictions and the original data, \ + considering any configured loss weighting. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + + Returns: + tuple: A tuple containing two elements: + - dict: additional data that used to debug / logging / callbacks + - Tensor: The computed loss for the training step as a PyTorch Tensor. + + Raises: + AssertionError: If the class is conditional, \ + but no number of classes is specified in the network configuration. 
+ + Notes: + - The method handles different types of conditioning + - The method also supports Kendall's loss + """ + # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. + x0_from_data_batch, x0, condition = self.get_data_and_condition(data_batch) + + # Sample pertubation noise levels and N(0, 1) noises + sigma, epsilon = self.draw_training_sigma_and_epsilon(x0.size(), condition) + + output_batch, kendall_loss, pred_mse, edm_loss = self.compute_loss_with_epsilon_and_sigma( + data_batch, x0_from_data_batch, x0, condition, epsilon, sigma + ) + + if self.loss_reduce == "mean": + kendall_loss = kendall_loss.mean() * self.loss_scale + elif self.loss_reduce == "sum": + kendall_loss = kendall_loss.sum(dim=1).mean() * self.loss_scale + else: + raise ValueError(f"Invalid loss_reduce: {self.loss_reduce}") + + return output_batch, kendall_loss + + def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. + condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + + if getattr(self.config, "use_dummy_temporal_dim", False): + # When using video DiT model for image, we need to use a dummy temporal dimension. + xt = xt.unsqueeze(2) + + xt = xt.to(**self.tensor_kwargs) + sigma = sigma.to(**self.tensor_kwargs) + # get precondition for the network + c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) + + # forward pass through the network + net_output = self.net( + x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf + **condition.to_dict(), + ) + + logvar = self.model.logvar(c_noise) + x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) + + # get noise prediction based on sde + eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) + + if getattr(self.config, "use_dummy_temporal_dim", False): + x0_pred = x0_pred.squeeze(2) + eps_pred = eps_pred.squeeze(2) + + return DenoisePrediction(x0_pred, eps_pred, logvar) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + return self.vae.encode(state) * self.sigma_data + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + return self.vae.decode(latent / self.sigma_data) + + def draw_training_sigma_and_epsilon(self, x0_size: int, condition: Any) -> torch.Tensor: + del condition + batch_size = x0_size[0] + epsilon = torch.randn(x0_size, **self.tensor_kwargs) + return self.sde.sample_t(batch_size).to(**self.tensor_kwargs), epsilon + + def get_data_and_condition(self, data_batch: dict[str, torch.Tensor]) -> Tuple[torch.Tensor, CosmosCondition]: + """ + processing data batch draw from data loader and return data and condition that used for denoising task + + Returns: + raw_state (tensor): the image / video data that feed to vae + latent_state (tensor): nosie-free state, the vae latent state + condition (CosmosCondition): condition information for conditional generation. 
Generated from conditioner + """ + raw_state = data_batch[self.input_data_key] + latent_state = self.encode(raw_state) + condition = self.conditioner(data_batch) + return raw_state, latent_state, condition + + def get_per_sample_weight(self, data_batch: dict[str, torch.Tensor], batch_size: int): + r""" + extra weight for each sample, for example, aesthetic weight + Args: + data_batch: raw data batch draw from the training data loader. + batch_size: int, the batch size of the input data + """ + aesthetic_cfg = getattr(self.config, "aesthetic_finetuning", None) + if (aesthetic_cfg is not None) and getattr(aesthetic_cfg, "enabled", False): + sample_weight = data_batch["aesthetic_weight"] + else: + sample_weight = torch.ones(batch_size, **self.tensor_kwargs) + + camera_cfg = getattr(self.config, "camera_sample_weight", None) + if (camera_cfg is not None) and getattr(camera_cfg, "enabled", False): + sample_weight *= 1 + (data_batch["camera_attributes"][:, 1:].sum(dim=1) != 0) * (camera_cfg.weight - 1) + return sample_weight + + def get_per_sample_loss_mask(self, data_batch, raw_x_shape, latent_x_shape): + """ + extra loss mask for each sample, for example, human faces, hands. + + Args: + data_batch (dict): raw data batch draw from the training data loader. + raw_x_shape (tuple): shape of the input data. We need the raw_x_shape for necessary resize operation. + latent_x_shape (tuple): shape of the latent data + """ + if self.config.loss_mask_enabled: + raw_x_shape = [raw_x_shape[0], 1, *raw_x_shape[2:]] + weights = create_per_sample_loss_mask( + self.loss_masking, data_batch, raw_x_shape, torch.get_default_dtype(), "cuda" + ) + return F.interpolate(weights, size=latent_x_shape[2:], mode="bilinear") + + return 1.0 + + def get_per_sigma_loss_weights(self, sigma: torch.Tensor): + """ + Args: + sigma (tensor): noise level + + Returns: + loss weights per sigma noise level + """ + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + def generate_samples(self, batch_size: int, condition: CosmosCondition) -> torch.Tensor: + """ + Generate samples with given condition. It is WITHOUT classifier-free-guidance. + + Args: + batch_size (int): + condition (CosmosCondition): condition information generated from self.conditioner + """ + x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max + + def x0_fn(x, t): + return self.denoise(x, t, condition).x0 # ODE function + + return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) + + def generate_cfg_samples( + self, batch_size: int, condition: CosmosCondition, uncondition: CosmosCondition, guidance=1.5 + ) -> torch.Tensor: + """ + Generate samples with with classifier-free-guidance. + + Args: + batch_size (int): + condition (CosmosCondition): condition information generated from self.conditioner + uncondition (CosmosCondition): uncondition information, possibily generated from self.conditioner + """ + x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max + + def x0_fn(x, t): + cond_x0 = self.denoise(x, t, condition).x0 + uncond_x0 = self.denoise(x, t, uncondition).x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. 
+ + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + return cond_x0 + guidance * (cond_x0 - uncond_x0) + + return x0_fn + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Optional[Tuple] = None, + n_sample: Optional[int] = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + ) -> torch.Tensor: + """ + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + """ + x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) + batch_size = n_sample or data_batch[self.input_data_key].shape[0] + state_shape = state_shape or self.state_shape + x_sigma_max = ( + misc.arch_invariant_rand( + (batch_size,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + return self.sampler( + x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max, num_steps=num_steps, solver_option=solver_option + ) + + @torch.no_grad() + def validation_step( + self, data: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Current code does nothing. + """ + return {}, torch.tensor(0).to(**self.tensor_kwargs) + + @torch.no_grad() + def forward(self, xt, t, condition: CosmosCondition): + """ + Performs denoising on the input noise data, noise level, and condition + + Args: + xt (torch.Tensor): The input noise data. + sigma (torch.Tensor): The noise level. 
+ condition (CosmosCondition): conditional information, generated from self.conditioner + + Returns: + DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ + noise prediction (eps_pred) and optional confidence (logvar). + """ + return self.denoise(xt, t, condition) + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the model. + + Args: + config_model (ModelConfig): The config object for the model. + + Returns: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + """ + optimizer = lazy_instantiate(optimizer_config, model=self.model) + scheduler = get_base_scheduler(optimizer, self, scheduler_config) + return optimizer, scheduler + + def state_dict(self) -> Dict[str, Any]: + """ + Returns the current state of the model as a dictionary. + + Returns: + Dict: The current state of the model as a dictionary. + """ + return { + "model": self.model.state_dict(), + "ema": self.model_ema.state_dict() if self.config.ema.enabled else None, + } + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): + """ + Loads a state dictionary into the model and optionally its EMA counterpart. + Different from torch strict=False mode, the method will not raise error for unmatched state shape while raise warning. + + Parameters: + state_dict (Mapping[str, Any]): A dictionary containing separate state dictionaries for the model and + potentially for an EMA version of the model under the keys 'model' and 'ema', respectively. + strict (bool, optional): If True, the method will enforce that the keys in the state dict match exactly + those in the model and EMA model (if applicable). Defaults to True. + assign (bool, optional): If True and in strict mode, will assign the state dictionary directly rather than + matching keys one-by-one. This is typically used when loading parts of state dicts + or using customized loading procedures. Defaults to False. + """ + if strict: + # the converted tpsp checkpoint has "ema" and it is None + if self.config.ema.enabled and state_dict["ema"] is not None: + ema_results: _IncompatibleKeys = self.model_ema.load_state_dict( + state_dict["ema"], strict=strict, assign=assign + ) + reg_results: _IncompatibleKeys = self.model.load_state_dict( + state_dict["model"], strict=strict, assign=assign + ) + if self.config.ema.enabled and state_dict["ema"] is not None: + return _IncompatibleKeys( + ema_results.missing_keys + reg_results.missing_keys, + ema_results.unexpected_keys + reg_results.unexpected_keys, + ) + return reg_results + else: + from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model + + log.critical("load model in non-strict mode") + if "model" in state_dict: + log.critical(non_strict_load_model(self.model, state_dict["model"]), rank0_only=False) + else: + log.critical(non_strict_load_model(self.model, state_dict), rank0_only=False) + if self.config.ema.enabled and "ema" in state_dict and state_dict["ema"] is not None: + log.critical("load ema model in non-strict mode") + log.critical(non_strict_load_model(self.model_ema, state_dict["ema"]), rank0_only=False) + + def get_ckpt_postfix(self) -> Tuple[str, int, int]: + """Get the checkpoint file postfix. + + Args: + iteration (int): The current iteration number. 
+ + Returns: + postfix (str): The postfix of the checkpoint file. + rank_to_save ema (int), we will not save each ema model in each rank, \ + ema model with same rate will be saved once + total_ema_num (int) + """ + total_ema_num = min(self.config.ema.num, distributed.get_world_size()) + rank = distributed.get_rank() + if rank == 0: + return "", 0, total_ema_num + if self.config.ema.enabled: + if rank < self.config.ema.num: + return f"_RANK{rank}", rank, total_ema_num + return "", 0, total_ema_num # use rank 0 to save the checkpoint + + @contextmanager + def ema_scope(self, context=None, is_cpu=False): + if self.config.ema.enabled: + self.model_ema.cache(self.model.parameters(), is_cpu=is_cpu) + self.model_ema.copy_to(self.model) + if context is not None: + log.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.config.ema.enabled: + self.model_ema.restore(self.model.parameters()) + if context is not None: + log.info(f"{context}: Restored training weights") + + +T = TypeVar("T", bound=DiffusionModel) + + +def diffusion_fsdp_class_decorator(base_class: Type[T]) -> Type[T]: + """ + Decorator for the FSDP class for the diffusion model, which handles the FSDP specific logic for the diffusion model. + """ + + class FSDPClass(base_class): + """ + Handle FSDP specific logic for the diffusion model. Including: + - FSDP model initialization + - FSDP model / optimizer save and loading + - Different from the original DiffusionModel, the impl of multi-rank EMA is a bit hacky. \ + We need to make sure sharded model weights for EMA and regular model are the same. + """ + + def __init__(self, config, fsdp_checkpointer: Any): + self.fsdp_checkpointer = fsdp_checkpointer + super().__init__(config) + + def set_up_model(self): + config = self.config + + # 1. build FSDP sharding strategy and device_mesh + strategy = { + "full": ShardingStrategy.FULL_SHARD, + "hybrid": ShardingStrategy.HYBRID_SHARD, + }[config.fsdp.sharding_strategy] + log.critical(f"Using {strategy} sharding strategy for FSDP") + + if config.fsdp.sharding_strategy == "hybrid": + sharding_group_size = getattr(config.fsdp, "sharding_group_size", 8) + device_mesh = hsdp_device_mesh( + sharding_group_size=sharding_group_size, + ) + shard_group = device_mesh.get_group(mesh_dim="shard") + replicate_group = device_mesh.get_group(mesh_dim="replicate") + fsdp_process_group = (shard_group, replicate_group) + else: + device_mesh = hsdp_device_mesh( + sharding_group_size=distributed.get_world_size(), + ) + shard_group = device_mesh.get_group(mesh_dim="shard") + fsdp_process_group = shard_group + + # We piggyback the `device_mesh` to megatron-core's `parallel_state` for global access. + # This is not megatron-core's original API. 
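# Illustrative sketch (not part of this patch): the `device_mesh` built above is a 2-D mesh with
# named dims ("replicate", "shard"). HYBRID_SHARD shards parameters within each "shard" group
# (typically one node of `sharding_group_size` GPUs) and synchronizes gradients across the
# "replicate" dim, which is why the hybrid branch hands FSDP the pair
# (shard_group, replicate_group). Assuming the custom `hsdp_device_mesh` helper wraps torch's
# public device-mesh API (an assumption; names and sizes below are hypothetical), the object it
# returns behaves roughly like:
#
#     from torch.distributed.device_mesh import init_device_mesh
#     mesh = init_device_mesh(
#         "cuda",
#         (world_size // sharding_group_size, sharding_group_size),
#         mesh_dim_names=("replicate", "shard"),
#     )
#     shard_group = mesh.get_group(mesh_dim="shard")          # FSDP shards within this group
#     replicate_group = mesh.get_group(mesh_dim="replicate")  # gradient sync across replicas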
+ parallel_state.fsdp_device_mesh = device_mesh + + def get_wrap_policy(_model): + if not hasattr(_model.net, "fsdp_wrap_block_cls"): + raise ValueError( + "Networks does not have fsdp_wrap_block_cls attribute, please check the net definition" + ) + fsdp_blocks_cls = _model.net.fsdp_wrap_block_cls + fsdp_blocks_cls = ( + list(fsdp_blocks_cls) if isinstance(fsdp_blocks_cls, (list, tuple, set)) else [fsdp_blocks_cls] + ) + log.critical(f"Using FSDP blocks {fsdp_blocks_cls}") + + log.critical(f"Using wrap policy {config.fsdp.policy}") + if config.fsdp.policy == "size": + min_num_params = getattr(config.fsdp, "min_num_params", 100) + log.critical(f"Using {min_num_params} as the minimum number of parameters for auto-wrap policy") + wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) + else: + from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + + wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=set(fsdp_blocks_cls), + ) + return wrap_policy + + # 2. build naive pytorch model and load weights if exists + replica_idx, shard_idx = device_mesh.get_coordinate() + # 2.1 handle ema case first, since float32 is more expensive + if config.ema.enabled: + with misc.timer("Creating PyTorch model and loading weights for ema"): + model_ema = self.build_model().float() + model_ema.cuda().eval().requires_grad_(False) + if distributed.get_rank() == 0: + # only load model in rank0 to reduce network traffic + self.fsdp_checkpointer.load_model_during_init(model_ema, is_ema=True) + # sync ema model weights from rank0 + with misc.timer("Sync model states for EMA model"): + #! this is IMPORTANT, see the following comment about regular model for details + #! we broadcast the ema model first, since it is fp32 and costs more memory + distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="shard")) + torch.cuda.empty_cache() + distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="replicate")) + torch.cuda.empty_cache() + # for ema model with dfiferent rate, we download the model when necessary + if shard_idx == 0 and replica_idx > 0 and replica_idx < config.ema.num: + print("loading ema model in rank", replica_idx) + self.fsdp_checkpointer.load_model_during_init( + model_ema, + is_ema=True, + ema_id=replica_idx, + ) + print("finish loading ema model in rank", replica_idx) + # 2.1.2 create FSDP model for ema model + with misc.timer("Creating FSDP model for EMA model"): + self.model_ema = FSDP( + model_ema, + sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync + process_group=device_mesh.get_group(mesh_dim=1), + sharding_strategy=ShardingStrategy.FULL_SHARD, + auto_wrap_policy=get_wrap_policy(model_ema), + device_id=torch.cuda.current_device(), + limit_all_gathers=True, + ) + + # extra ema model upate logic to the model + self.model_ema_worker = FastEmaModelUpdater() + s = 0.1 + replica_idx, shard_idx = device_mesh.get_coordinate() + divider = 2**replica_idx if replica_idx < config.ema.num else 1 + if replica_idx < config.ema.num: + if shard_idx == 0: + print(f"EMA: rank {replica_idx}, rate {config.ema.rate / divider}") + s = config.ema.rate / divider + self.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() + + torch.cuda.empty_cache() + + # 2.2 handle regular model + with misc.timer("Creating PyTorch model and loading weights for regular model"): + model = self.build_model().cuda().to(**self.tensor_kwargs) + + if 
distributed.get_rank() == 0: + # only load model in rank0 to reduce network traffic and sync later + self.fsdp_checkpointer.load_model_during_init(model, is_ema=False) + + #! overwrite the forward method so that it will invoke the FSDP-specific pre- and post-forward sharding logic + model.forward = super().training_step + #! this is IMPORTANT, though following two lines are identical to sync_module_states=True in FSDP + #! we do it twice so that following line can warm up and avoid OOM in aws 128+ nodes settings + #! qsh hypothesize that it is due to overhead of initialization of nccl network communication; + #! without it, peak mem : reg_model + ema_model + FSDP overhead + nccl communication initialization overhead + #! with it, peak men: reg_model + ema_model + FSDP overhead + #! it is tricky, but it works! + with misc.timer("Sync model states for regular model"): + distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="shard")) + torch.cuda.empty_cache() + distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="replicate")) + torch.cuda.empty_cache() + + with misc.timer("Creating FSDP model"): + self.model = FSDP( + model.to(**self.tensor_kwargs), + sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync + sharding_strategy=strategy, + auto_wrap_policy=get_wrap_policy(model), + process_group=fsdp_process_group, + limit_all_gathers=True, + ) + + if self.config.fsdp.checkpoint: + fsdp_blocks_cls = model.net.fsdp_wrap_block_cls + fsdp_blocks_cls = ( + list(fsdp_blocks_cls) + if isinstance(fsdp_blocks_cls, (list, tuple, set)) + else [fsdp_blocks_cls] + ) + log.critical(f"Applying FSDP checkpointing with FSDP blocks: {fsdp_blocks_cls}") + apply_fsdp_checkpointing(self.model, list_block_cls=fsdp_blocks_cls) + + torch.cuda.empty_cache() + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + del scheduler, optimizer + + if self.config.ema.enabled: + # calculate beta for EMA update + if iteration == 0: + beta = 0.0 + else: + i = iteration + 1 + beta = (1 - 1 / i) ** (self.ema_exp_coefficient + 1) + self.model_ema_worker.update_average(self.model, self.model_ema, beta=beta) + + def training_step( + self, data_batch: Dict[str, torch.Tensor], iteration: int + ) -> Tuple[Dict[str, torch.Tensor] | torch.Tensor]: + # ! Important!!! + # ! make sure the training step is the same as the forward method~(training_step in the super class) + # ! 
this is necessary to trigger the FSDP-specific pre- and post-forward sharding logic + return self.model(data_batch, iteration) + + def state_dict(self) -> Dict: + raise NotImplementedError( + "FSDPDiffModle does not support state_dict, use state_dict_model and FSDPCheckpointer" + ) + + @misc.timer("FSDP state_dict_model") + def state_dict_model(self) -> Dict: + with FSDP.summon_full_params(self.model): + pass + with FSDP.state_dict_type( + self.model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + model_state = self.model.state_dict() + if self.config.ema.enabled: + with FSDP.summon_full_params(self.model_ema): + pass + with FSDP.state_dict_type( + self.model_ema, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + ema_model_state = self.model_ema.state_dict() + else: + ema_model_state = None + return { + "model": model_state, + "ema": ema_model_state, + } + + def load_state_dict(self, state_dict: Dict, strict: bool = True, assign: bool = False) -> None: + raise NotImplementedError("FSDPDiffModle does not support load_state_dict, using FSDPCheckpointer") + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + optimizer, scheduler = super().init_optimizer_scheduler(optimizer_config, scheduler_config) + self.fsdp_checkpointer.load_optim_scheduler_during_init( + self.model, + optimizer, + scheduler, + ) + return optimizer, scheduler + + @contextmanager + def ema_scope(self, context=None, is_cpu=False): + if self.config.ema.enabled: + self.model_ema_worker.cache(self.model.parameters(), is_cpu=is_cpu) + self.model_ema_worker.copy_to(src_model=self.model_ema, tgt_model=self.model) + if context is not None: + log.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.config.ema.enabled: + self.model_ema_worker.restore(self.model.parameters()) + if context is not None: + log.info(f"{context}: Restored training weights") + + def get_ckpt_postfix(self) -> Tuple[str, int]: + """Get the checkpoint file postfix. check FSDPCheckpointer for more details + + Args: + iteration (int): The current iteration number. + + Returns: + postfix (str): The postfix of the checkpoint file. + replicate_idx, shard_idx (int), current gpu replicate_idx, shard_idx in FSDP \ + we will not save each ema model in each GPU, \ + ema model with same rate will be saved once + total_ema_num (int) + """ + mesh_shape = parallel_state.fsdp_device_mesh.shape + total_ema_num = min(self.config.ema.num, mesh_shape[0]) + replicate_idx, shard_idx = parallel_state.fsdp_device_mesh.get_coordinate() + if replicate_idx == 0: + return "", 0, shard_idx, total_ema_num + if self.config.ema.enabled: + if replicate_idx < self.config.ema.num: + return f"_RANK{replicate_idx}", replicate_idx, shard_idx, total_ema_num + return "", replicate_idx, shard_idx, total_ema_num + + return FSDPClass + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(DiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model_multiview.py b/cosmos_predict1/diffusion/training/models/model_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..8cae2d6b0561f69fdc706d579a357b6940674015 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model_multiview.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import Tensor + +from cosmos_predict1.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict1.diffusion.training.models.model import DiffusionModel, broadcast_condition +from cosmos_predict1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator +from cosmos_predict1.utils import log, misc + + +class MultiviewDiffusionModel(DiffusionModel): + def __init__(self, config): + super().__init__(config) + self.n_views = config.n_views + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + encoded_state = self.vae.encode(state) + encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data + return encoded_state + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + decoded_state = self.vae.decode(latent / self.sigma_data) + decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + return decoded_state + + def compute_loss_with_epsilon_and_sigma( + self, + data_batch: dict[str, torch.Tensor], + x0_from_data_batch: torch.Tensor, + x0: torch.Tensor, + condition: CosmosCondition, + epsilon: torch.Tensor, + sigma: torch.Tensor, + ): + if self.is_image_batch(data_batch): + # Turn off CP + self.net.disable_context_parallel() + else: + if parallel_state.is_initialized(): + if parallel_state.get_context_parallel_world_size() > 1: + # Turn on CP + cp_group = parallel_state.get_context_parallel_group() + self.net.enable_context_parallel(cp_group) + log.debug("[CP] Split x0 and epsilon") + + x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) + epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) + + x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + output_batch, kendall_loss, pred_mse, edm_loss = super( + DiffusionModel, self + ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) + if not self.is_image_batch(data_batch): + if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: + kendall_loss *= parallel_state.get_context_parallel_world_size() + + return 
output_batch, kendall_loss, pred_mse, edm_loss + + def generate_samples_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + is_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + x_sigma_max: Optional[torch.Tensor] = None, + sigma_max: float | None = None, + guidance_other: Union[float, None] = None, + ) -> Tensor: + """ + Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. + guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to self.state_shape if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + """ + self._normalize_video_databatch_inplace(data_batch) + self._augment_image_dim_inplace(data_batch) + is_image_batch = self.is_image_batch(data_batch) + if n_sample is None: + input_key = self.input_image_key if is_image_batch else self.input_data_key + n_sample = data_batch[input_key].shape[0] + if state_shape is None: + if is_image_batch: + state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W + x0_fn = self.get_x0_fn_from_batch( + data_batch, guidance, is_negative_prompt=is_negative_prompt, guidance_other=guidance_other + ) + x_sigma_max = ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * self.sde.sigma_max + ) + if self.net.is_context_parallel_enabled: + x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) + + x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + samples = self.sampler( + x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option + ) + if self.net.is_context_parallel_enabled: + samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) + samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) + + return samples + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + is_negative_prompt: bool = False, + guidance_other: Union[float, None] = None, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. 
The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. + """ + if is_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + to_cp = self.net.is_context_parallel_enabled + # For inference, check if parallel_state is initialized + if parallel_state.is_initialized(): + condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) + uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) + else: + assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." + + if guidance_other is not None: + # assume this is for inference time trajectory guidance for now + assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances." + condition_other = copy.deepcopy(uncondition) + condition_other.trajectory = condition.trajectory + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + cond_other_x0 = self.denoise(noise_x, sigma, condition_other).x0 + + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0) + + if "guided_image" in data_batch: + assert False, "not supported" + return raw_x0 + + else: + + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 + + return x0_fn + + +@diffusion_fsdp_class_decorator +class FSDPDiffusionModel(MultiviewDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/models/model_peft.py b/cosmos_predict1/diffusion/training/models/model_peft.py new file mode 100644 index 0000000000000000000000000000000000000000..a25645e1d8f348e357d46539af0f248cacb30db7 --- /dev/null +++ b/cosmos_predict1/diffusion/training/models/model_peft.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Type, TypeVar + +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.diffusion.training.models.model import DiffusionModel as VideoDiffusionModel +from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser +from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType +from cosmos_predict1.utils import misc +from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate + +T = TypeVar("T") + + +def video_peft_decorator(base_class: Type[T]) -> Type[T]: + class PEFTVideoDiffusionModel(base_class): + def __init__(self, config: dict, fsdp_checkpointer=None): + super().__init__(config) + + @misc.timer("PEFTVideoDiffusionModel: set_up_model") + def set_up_model(self): + config = self.config + peft_control_config_parser = LayerControlConfigParser(config=config.peft_control) + peft_control_config = peft_control_config_parser.parse() + self.model = self.build_model() + if peft_control_config and peft_control_config["customization_type"] == CustomizationType.LORA: + add_lora_layers(self.model, peft_control_config) + num_lora_params = setup_lora_requires_grad(self.model) + if num_lora_params == 0: + raise ValueError("No LoRA parameters found. Please check the model configuration.") + if config.ema.enabled: + with misc.timer("PEFTDiffusionModel: instantiate ema"): + config.ema.model = self.model + self.model_ema = lazy_instantiate(config.ema) + config.ema.model = None + else: + self.model_ema = None + + def state_dict_model(self) -> Dict: + return { + "model": self.model.state_dict(), + "ema": self.model_ema.state_dict() if self.model_ema else None, + } + + return PEFTVideoDiffusionModel + + +@video_peft_decorator +class PEFTVideoDiffusionModel(VideoDiffusionModel): + pass + + +@video_peft_decorator +class PEFTExtendDiffusionModel(ExtendDiffusionModel): + pass diff --git a/cosmos_predict1/diffusion/training/module/blocks.py b/cosmos_predict1/diffusion/training/module/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..255ab483fd61706c1cda4e8457418ffbff373dbd --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/blocks.py @@ -0,0 +1,1118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
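+# Building blocks for the video DiT networks defined in this repository: timestep
+# embeddings (SDXLTimesteps, SDXLTimestepEmbedding), patch embeddings (PatchEmbed and
+# variants), expert-choice MoE layers, adaLN-modulated attention/MLP blocks
+# (DITBuildingBlock), and the output projection (FinalLayer). The *_memory_save
+# forward variants trade recomputation for lower peak memory via torch.utils.checkpoint.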
+ +import math +from typing import Optional + +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from megatron.core import parallel_state +from torch import nn +from transformer_engine.pytorch.attention import apply_rotary_pos_emb + +from cosmos_predict1.diffusion.module.attention import Attention, GPT2FeedForward +from cosmos_predict1.diffusion.training.tensor_parallel import gather_along_first_dim +from cosmos_predict1.utils import log + + +class SDXLTimesteps(nn.Module): + def __init__(self, num_channels: int = 320): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + in_dtype = timesteps.dtype + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb.to(in_dtype) + + +class SDXLTimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + log.critical( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + else: + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_3D = emb + emb_B_D = sample + else: + emb_B_D = emb + adaln_lora_B_3D = None + + return emb_B_D, adaln_lora_B_3D + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor, applying either a 3D convolution or a patch-wise linear projection depending on the `legacy_patch_emb` flag. This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches and embedding each + patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + - keep_spatio (bool): If True, the spatial dimensions are kept separate in the output tensor, otherwise, they are flattened. Default: False. + - legacy_patch_emb (bool): If True, applies a 3D convolutional layer to embed patches; otherwise, a Linear layer is used. The legacy (convolutional) path is kept for backward compatibility. Default: True. + The output shape of the module depends on the `keep_spatio` flag. If `keep_spatio`=True, the output retains the spatial dimensions. + Otherwise, the spatial dimensions are flattened into a single dimension. 
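+
+    A minimal shape example (illustrative; sizes are arbitrary, and keep_spatio=True is required by this class):
+        >>> patcher = PatchEmbed(spatial_patch_size=2, temporal_patch_size=1, in_channels=3, out_channels=768, keep_spatio=True)
+        >>> x = torch.randn(1, 3, 8, 64, 64)  # (B, C, T, H, W)
+        >>> tuple(patcher(x).shape)  # (B, T, H/2, W/2, out_channels)
+        (1, 8, 32, 32, 768)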
+ """ + + def __init__( + self, + spatial_patch_size, + temporal_patch_size, + in_channels=3, + out_channels=768, + bias=True, + keep_spatio=False, + legacy_patch_emb: bool = True, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + assert keep_spatio, "Only support keep_spatio=True" + self.keep_spatio = keep_spatio + self.legacy_patch_emb = legacy_patch_emb + + if legacy_patch_emb: + self.proj = nn.Conv3d( + in_channels, + out_channels, + kernel_size=(temporal_patch_size, spatial_patch_size, spatial_patch_size), + stride=(temporal_patch_size, spatial_patch_size, spatial_patch_size), + bias=bias, + ) + self.out = Rearrange("b c t h w -> b t h w c") + else: + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + nn.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias + ), + ) + self.out = nn.Identity() + + def forward(self, x): + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return self.out(x) + + +class ExtraTokenPatchEmbed(PatchEmbed): + def __init__(self, *args, out_channels: int = 768, keep_spatio: bool = False, **kwargs): + assert keep_spatio, "ExtraTokenPatchEmbed only supports keep_spatio=True" + super().__init__(*args, out_channels=out_channels, keep_spatio=keep_spatio, **kwargs) + self.temporal_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) + self.spatial_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) + + def forward(self, x): + x_B_T_H_W_C = super().forward(x) + B, T, H, W, C = x_B_T_H_W_C.shape + x_B_T_H_W_C = torch.cat( + [ + x_B_T_H_W_C, + self.temporal_token.repeat(B, 1, H, W, 1), + ], + dim=1, + ) + x_B_T_H_W_C = torch.cat( + [ + x_B_T_H_W_C, + self.spatial_token.repeat(B, T, H, 1, 1), + ], + dim=3, + ) + return x_B_T_H_W_C + + +class ExpertChoiceMoEGate(nn.Module): + """ + ExpertChoiceMoEGate determines which tokens go + to which experts (and how much to weigh each expert). + + Args: + hidden_size (int): Dimensionality of input features. + num_experts (int): Number of experts (E). + capacity (int): Capacity (number of tokens) each expert can process (C). + """ + + def __init__( + self, + hidden_size: int, + num_experts: int, + capacity: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + self.capacity = capacity + + self.router = nn.Parameter(torch.empty((self.num_experts, self.hidden_size))) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.router) + + def forward(self, x: torch.Tensor): + """ + Args: + x (Tensor): Input of shape (B, S, D) + Returns: + gating (Tensor): Gating weights of shape (B, E, C), + where E = num_experts, C = capacity (top-k). + dispatch (Tensor): Dispatch mask of shape (B, E, C, S). + index (Tensor): Indices of top-k tokens for each expert, + shape (B, E, C). 
+ """ + B, S, D = x.shape + E, C = self.num_experts, self.capacity + + # token-expert affinity scores + logits = torch.einsum("bsd,de->bse", x, self.router) + affinity = torch.nn.functional.softmax(logits, dim=-1) # (B, S, E) + + # gather topk tokens for each expert + affinity_t = affinity.transpose(1, 2) # (B, E, S) + + # select top-k tokens for each expert + gating, index = torch.topk(affinity_t, k=C, dim=-1) # (B, E, C) + + # one-hot dispatch mask + dispatch = torch.nn.functional.one_hot(index, num_classes=S).float() # (B, E, C, S) + + return gating, dispatch, index + + +class ExpertChoiceMoELayer(nn.Module): + """ + ExpertChoiceMoELayer uses the ExpertChoiceMoEGate to route tokens + to experts, process them, and then combine the outputs. + + Args: + gate_hidden_size (int): Dimensionality of input features. + ffn_hidden_size (int): Dimension of hidden layer in each expert feedforward (e.g., GPT2FeedForward). + num_experts (int): Number of experts (E). + capacity (int): Capacity (number of tokens) each expert can process (C). + expert_cls (nn.Module): The class to instantiate each expert. Defaults to GPT2FeedForward. + expert_kwargs (dict): Extra kwargs to pass to each expert class. + """ + + def __init__( + self, + gate_hidden_size: int, + ffn_hidden_size: int, + num_experts: int, + capacity: int, + expert_class: nn.Module = GPT2FeedForward, + expert_kwargs=None, + ): + super().__init__() + if not expert_kwargs: + expert_kwargs = {} + + self.gate_hidden_size = gate_hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.num_experts = num_experts + self.capacity = capacity + + self.gate = ExpertChoiceMoEGate(gate_hidden_size, num_experts, capacity) + + self.experts = nn.ModuleList( + [expert_class(gate_hidden_size, ffn_hidden_size, **expert_kwargs) for _ in range(num_experts)] + ) + + def forward(self, x: torch.Tensor): + """ + Args: + x (Tensor): Input of shape (B, S, D). + + Returns: + x_out (Tensor): Output of shape (B, S, D), after dispatching tokens + to experts and combining their outputs. + """ + B, S, D = x.shape + E, C = self.num_experts, self.capacity + + # gating: (B, E, C) + # dispatch: (B, E, C, S) + gating, dispatch, index = self.gate(x) + + # collect input tokens for each expert + x_in = torch.einsum("becs,bsd->becd", dispatch, x) + + # process through each expert + expert_outputs = [self.experts[e](x_in[:, e]) for e in range(E)] + + x_e = torch.stack(expert_outputs, dim=1) # (B, E, C, D) + + # gating: (B, E, C), dispatch: (B, E, C, S), x_e: (B, E, C, d) + # x_out: (B, S, D) + # each token is placed back to its location with weighting + x_out = torch.einsum("becs,bec,becd->bsd", dispatch, gating, x_e) + + return x_out + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. 
+ """ + + def __init__( + self, + hidden_size, + spatial_patch_size, + temporal_patch_size, + out_channels, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) + ) + + self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + + def forward( + self, + x_BT_HW_D, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + if self.sequence_parallel: + x_T_B_HW_D = rearrange(x_BT_HW_D, "(b t) hw d -> t b hw d", b=B, t=T) + x_T_B_HW_D = gather_along_first_dim(x_T_B_HW_D, parallel_state.get_tensor_model_parallel_group()) + x_BT_HW_D = rearrange(x_T_B_HW_D, "t b hw d -> (b t) hw d", b=B) + + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + def forward_with_memory_save( + self, + x_BT_HW_D_before_gate: torch.Tensor, + x_BT_HW_D_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D_before_gate.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + gate_BT_1_D = repeat(gate_L_B_D, "1 b d -> (b t) 1 d", t=T) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_BT_1_D * _x_before_gate + _x = modulate(self.norm_final(previous_block_out), shift_BT_D, scale_BT_D) + return self.linear(_x) + + return torch.utils.checkpoint.checkpoint(_fn, x_BT_HW_D_before_gate, x_BT_HW_D_skip, use_reentrant=False) + + +class VideoAttn(nn.Module): + """ + Implements video attention with optional cross-attention capabilities. + + This module supports both self-attention within the video frames and cross-attention + with an external context. It's designed to work with flattened spatial dimensions + to accommodate for video input. + + Attributes: + x_dim (int): Dimensionality of the input feature vectors. + context_dim (Optional[int]): Dimensionality of the external context features. + If None, the attention does not utilize external context. + num_heads (int): Number of attention heads. 
+ bias (bool): If true, bias is added to the query, key, value projections. + x_format (str): The shape format of x tenosor. + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + """ + + def __init__( + self, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + bias: bool = False, + x_format: str = "BTHWD", + n_views: int = 1, + ) -> None: + super().__init__() + self.n_views = n_views + self.x_format = x_format + if self.x_format == "BTHWD": + qkv_format = "bshd" + elif self.x_format == "THWBD": + qkv_format = "sbhd" + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + + self.attn = Attention( + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_bias=bias, + qkv_norm="RRI", + out_bias=bias, + qkv_format=qkv_format, + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for video attention. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. + context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), where M is the sequence length of the context. + crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. + rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format + + Returns: + Tensor: The output tensor with applied attention, maintaining the input shape. + """ + + if self.x_format == "BTHWD": + if context is not None and self.n_views > 1: + x_B_T_H_W_D = rearrange(x, "b (v t) h w d -> (v b) t h w d", v=self.n_views) + context_B_M_D = rearrange(context, "b (v m) d -> (v b) m d", v=self.n_views) + else: + x_B_T_H_W_D = x + context_B_M_D = context + B, T, H, W, D = x_B_T_H_W_D.shape + x_B_THW_D = rearrange(x_B_T_H_W_D, "b t h w d -> b (t h w) d") + x_B_THW_D = self.attn(x_B_THW_D, context_B_M_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D) + + # reshape it back to video format + x_B_T_H_W_D = rearrange(x_B_THW_D, "b (t h w) d -> b t h w d", h=H, w=W) + if context is not None and self.n_views > 1: + x_B_T_H_W_D = rearrange(x_B_T_H_W_D, "(v b) t h w d -> b (v t) h w d", v=self.n_views) + return x_B_T_H_W_D + elif self.x_format == "THWBD": + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x, "(v t) h w b d -> t h w (v b) d", v=self.n_views) + context_M_B_D = rearrange(context, "(v m) b d -> m (v b) d", v=self.n_views) + else: + x_T_H_W_B_D = x + context_M_B_D = context + T, H, W, B, D = x_T_H_W_B_D.shape + x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") + x_THW_B_D = self.attn( + x_THW_B_D, + context_M_B_D, + crossattn_mask, + rope_emb=rope_emb_L_1_1_D, + ) + x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) + if context is not None and self.n_views > 1: + x_T_H_W_B_D = rearrange(x_T_H_W_B_D, "t h w (v b) d -> (v t) h w b d", v=self.n_views) + return x_T_H_W_B_D + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + + +def checkpoint_norm_state(norm_state, x, scale, shift): + normalized = norm_state(x) + return normalized * (1 + scale) + shift + + +class DITBuildingBlock(nn.Module): + """ + DIT Building Block for constructing various types of attention or MLP blocks dynamically based on a specified block type. 
+ + This class instantiates different types of buildig block / attn and MLP based on config, and applies crossponding forward pass during training. + + Attributes: + block_type (str): Type of block to be used ('spatial_sa', 'temporal_sa', 'cross_attn', 'full_attn', 'mlp'). + x_dim (int): Dimensionality of the input features. + context_dim (Optional[int]): Dimensionality of the external context, required for cross attention blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Multiplier for the dimensionality of the MLP hidden layer compared to input. + spatial_win_size (int): Window size for spatial self-attention. + temporal_win_size (int): Window size for temporal self-attention. + bias (bool): Whether to include bias in attention and MLP computations. + mlp_dropout (float): Dropout rate for MLP blocks. + n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. + """ + + def __init__( + self, + block_type: str, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + mlp_ratio: float = 4.0, + window_sizes: list = [], + spatial_win_size: int = 1, + temporal_win_size: int = 1, + bias: bool = False, + mlp_dropout: float = 0.0, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ) -> None: + block_type = block_type.lower() + + super().__init__() + self.x_format = x_format + if block_type in ["cross_attn", "ca"]: + self.block = VideoAttn( + x_dim, + context_dim, + num_heads, + bias=bias, + x_format=self.x_format, + n_views=n_views, + ) + elif block_type in ["full_attn", "fa"]: + self.block = VideoAttn(x_dim, None, num_heads, bias=bias, x_format=self.x_format) + elif block_type in ["mlp", "ff"]: + self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.block_type = block_type + self.use_adaln_lora = use_adaln_lora + + self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.n_adaln_chunks = 3 + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), + ) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) + + def forward_with_attn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_mask + assert isinstance(self.block, VideoAttn), "only support VideoAttn impl" + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip, _context): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = 
self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + # context = normalized_x if _context is None else _context + context = normalized_x if self.block.attn.is_selfattn else _context + return ( + self.block.attn.to_q[0](normalized_x), + self.block.attn.to_k[0](context), + self.block.attn.to_v[0](context), + previous_block_out, + ) + + q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, crossattn_emb, use_reentrant=False + ) + + def attn_fn(_q, _k, _v): + q, k, v = map( + lambda t: rearrange( + t, + "b ... (n c) -> b ... n c", + n=self.block.attn.heads // self.block.attn.tp_size, + c=self.block.attn.dim_head, + ), + (_q, _k, _v), + ) + q = self.block.attn.to_q[1](q) + k = self.block.attn.to_k[1](k) + v = self.block.attn.to_v[1](v) + if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + + if self.block.attn.is_selfattn: + return q, k, v + + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + return self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + + assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." + + if self.block.attn.is_selfattn: + q, k, v = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." 
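+ # Self-attention path: the q/k/v projections and RoPE above are recomputed under activation
+ # checkpointing, while the fused TE attention op below runs outside the checkpoint.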
+ softmax_attn_output = self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + else: + softmax_attn_output = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) + attn_out = self.block.attn.to_out(softmax_attn_output) + return _gate_L_B_D, attn_out, previous_block_out + + def forward_with_x_attn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_mask + assert isinstance(self.block, VideoAttn) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip, _context): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + # context = normalized_x if _context is None else _context + context = normalized_x if self.block.attn.is_selfattn else _context + return ( + self.block.attn.to_q[0](normalized_x), + self.block.attn.to_k[0](context), + self.block.attn.to_v[0](context), + previous_block_out, + ) + + q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, crossattn_emb, use_reentrant=False + ) + + def x_attn_fn(_q, _k, _v): + q, k, v = map( + lambda t: rearrange( + t, + "b ... (n c) -> b ... n c", + n=self.block.attn.heads // self.block.attn.tp_size, + c=self.block.attn.dim_head, + ), + (_q, _k, _v), + ) + q = self.block.attn.to_q[1](q) + k = self.block.attn.to_k[1](k) + v = self.block.attn.to_v[1](v) + if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) + + seq_dim = self.block.attn.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + softmax_attn_output = self.block.attn.attn_op( + q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None + ) # [B, Mq, H, V] + return self.block.attn.to_out(softmax_attn_output) + + assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." 
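+ # Cross-attention path: the second-stage q/k/v transforms, the fused attention op, and the
+ # output projection are all recomputed inside a single activation checkpoint.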
+ + attn_out = torch.utils.checkpoint.checkpoint(x_attn_fn, q, k, v, use_reentrant=False) + return _gate_L_B_D, attn_out, previous_block_out + + def forward_with_ffn_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D + assert isinstance(self.block, GPT2FeedForward) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + + assert self.block.dropout.p == 0.0, "we skip dropout to save memory" + + return self.block.layer1(normalized_x), previous_block_out + + intermediate_output, previous_block_out = torch.utils.checkpoint.checkpoint( + _fn, x_before_gate, x_skip, use_reentrant=False + ) + + def _fn2(_x): + _x = self.block.activation(_x) + return self.block.layer2(_x) + + return ( + _gate_L_B_D, + torch.utils.checkpoint.checkpoint(_fn2, intermediate_output, use_reentrant=False), + previous_block_out, + ) + + def forward_with_ffn_memory_save_upgrade( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D + assert isinstance(self.block, GPT2FeedForward) + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( + shift_B_D.unsqueeze(0), + scale_B_D.unsqueeze(0), + gate_B_D.unsqueeze(0), + ) + + def _fn2(_x): + _x = self.block.activation(_x) + return self.block.layer2(_x) + + def _fn(_x_before_gate, _x_skip): + previous_block_out = _x_skip + gate_L_B_D * _x_before_gate + if extra_per_block_pos_emb is not None: + previous_block_out = previous_block_out + extra_per_block_pos_emb + _normalized_x = self.norm_state(previous_block_out) + normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D + + assert self.block.dropout.p == 0.0, "we skip dropout to save memory" + + return _fn2(self.block.layer1(normalized_x)), previous_block_out + + output, previous_block_out = torch.utils.checkpoint.checkpoint(_fn, x_before_gate, x_skip, use_reentrant=False) + + return ( + _gate_L_B_D, + output, + previous_block_out, + ) + + def forward_with_memory_save( + self, + x_before_gate: torch.Tensor, + 
x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + if isinstance(self.block, VideoAttn): + if self.block.attn.is_selfattn: + fn = self.forward_with_attn_memory_save + else: + fn = self.forward_with_x_attn_memory_save + else: + # fn = self.forward_with_ffn_memory_save + fn = self.forward_with_ffn_memory_save_upgrade + return fn( + x_before_gate, + x_skip, + gate_L_B_D, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for dynamically configured blocks with adaptive normalization. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). + emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. + crossattn_emb (Tensor): Tensor for cross-attention blocks. + crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. + rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format + + Returns: + Tensor: The output tensor after processing through the configured block and adaptive normalization. + """ + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + if self.x_format == "BTHWD": + shift_B_1_1_1_D, scale_B_1_1_1_D, gate_B_1_1_1_D = ( + shift_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + scale_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + gate_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), + ) + if self.block_type in ["spatial_sa", "temporal_sa", "window_attn", "ssa", "tsa", "wa"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + crossattn_emb, + crossattn_mask, + ) + elif self.block_type in ["mlp", "ff"]: + x = x + gate_B_1_1_1_D * self.block( + self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + elif self.x_format == "THWBD": + shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( + shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + ) + + if self.block_type in ["mlp", "ff"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + ) + elif self.block_type in 
["full_attn", "fa"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_1_1_1_B_D * self.block( + torch.utils.checkpoint.checkpoint( + checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False + ), + context=crossattn_emb, + crossattn_mask=crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + else: + raise NotImplementedError(f"Unsupported x_format: {self.x_format}") + return x + + +class GeneralDITTransformerBlock(nn.Module): + """ + This class is a wrapper for a list of DITBuildingBlock. + It's not essential, refactor it if needed. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + block_config: str, + mlp_ratio: float = 4.0, + window_sizes: list = [], + spatial_attn_win_size: int = 1, + temporal_attn_win_size: int = 1, + use_checkpoint: bool = False, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + n_views: int = 1, + ): + super().__init__() + self.blocks = nn.ModuleList() + self.x_format = x_format + for block_type in block_config.split("-"): + self.blocks.append( + DITBuildingBlock( + block_type, + x_dim, + context_dim, + num_heads, + mlp_ratio, + window_sizes, + spatial_attn_win_size, + temporal_attn_win_size, + x_format=self.x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + n_views=n_views, + ) + ) + self.use_checkpoint = use_checkpoint + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + else: + return self._forward( + x, emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, extra_per_block_pos_emb + ) + + def _forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x = x + extra_per_block_pos_emb + for block in self.blocks: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + return x + + def set_memory_save(self, mode: bool = True): + # (qsh) to make fsdp happy! + #! IMPORTANT! 
+ if mode: + self.forward = self.forward_with_memory_save + for block in self.blocks: + block.forward = block.forward_with_memory_save + else: + raise NotImplementedError("Not implemented yet.") + + def forward_with_memory_save( + self, + x_before_gate: torch.Tensor, + x_skip: torch.Tensor, + gate_L_B_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ): + for block in self.blocks: + gate_L_B_D, x_before_gate, x_skip = block.forward( + x_before_gate, + x_skip, + gate_L_B_D, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_per_block_pos_emb, + ) + extra_per_block_pos_emb = None + return gate_L_B_D, x_before_gate, x_skip diff --git a/cosmos_predict1/diffusion/training/module/position_embedding.py b/cosmos_predict1/diffusion/training/module/position_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..83625e8a2c6e59352c2786e9fbd699c7c13e2a36 --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/position_embedding.py @@ -0,0 +1,932 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
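+# Positional-embedding variants for the video DiT: factorized 3D sin/cos embeddings
+# (fixed, learnable, and FPS-aware) and rotary embeddings, with optional context-parallel
+# splitting of the generated embedding along the sequence dimension.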
+ +from typing import Literal, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.diffusion.module.attention import normalize +from cosmos_predict1.diffusion.module.timm import trunc_normal_ +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size_h, + grid_size_w, + grid_size_t, + spatial_interpolation_scale, + temporal_interpolation_scale, + concat=True, +): + grid_h = np.arange(grid_size_h, dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(grid_size_w, dtype=np.float32) / spatial_interpolation_scale + grid_t = np.arange(grid_size_t, dtype=np.float32) / temporal_interpolation_scale + + grid = np.meshgrid(grid_w, grid_h, grid_t, indexing="ij") + grid = np.stack(grid, axis=0) + grid = grid.reshape(3, 1, grid_size_h, grid_size_w, grid_size_t) + + if concat: + per_axis = embed_dim // 3 + per_axis = (per_axis // 2) * 2 # make it even (for sin/cos split) + dim_h, dim_w = per_axis, per_axis + dim_t = embed_dim - dim_h - dim_w + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, grid[0]) # (H*W, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, grid[1]) # (H*W, D/3) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, grid[2]) # (H*W, D/3) + + return np.concatenate([emb_h, emb_w, emb_t], axis=1) # (H*W*T, D) + else: + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[0]) # (H*W) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[1]) # (H*W) + emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[2]) # (H*W) + + return emb_h + emb_w + emb_t # (H*W*T, D) + + +class VideoPositionEmb(nn.Module): + def __init__(self): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + With CP, the function assume that the input tensor is already split. It delegates the embedding generation to generate_embeddings function. 
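+
+        When a context-parallel group is set, the local input covers only T / cp_size of the full
+        sequence; the full (B, T, H, W, C) shape is reconstructed before calling generate_embeddings,
+        and the generated embedding is split back along the sequence dimension (dim 0 for
+        VideoRopePosition3DEmb, dim 1 otherwise).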
+ """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, VideoRopePosition3DEmb): + seq_dim = 0 + else: + seq_dim = 1 + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class SinCosPosEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + is_learnable: bool = False, + interpolation: Literal["crop", "resize", "crop_resize"] = "crop", + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + init_length_for_resize: int = 16, + **kwargs, + ): + """ + Args: + interpolation (str): "crop", "resize", "crop_resize". "crop" means we crop the positional embedding to the length of the input sequence. "resize" means we resize the positional embedding to the length of the input sequence. "crop_resize" (inference only) means we first crop the positional embedding to init_length_for_resize, then resize it to the length of the input sequence. + init_length_for_resize (int): used when interpolation is "crop_resize", where we "resize" embedding during inference for model trained with "crop". We first "crop" the pos_embed to this length (used during training), then run the "resize", default 16 + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + self.init_length_for_resize = init_length_for_resize + param = get_3d_sincos_pos_embed( + model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale + ) + param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) + if is_learnable: + self.pos_embed = nn.Parameter( + torch.from_numpy(param).float(), + ) + else: + self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + return self.pos_embed[:, :T, :H, :W] + if self.interpolation == "resize": + return rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T), + mode="linear", + align_corners=False, + ), + "1 c h w t -> 1 t h w c", + ) + if self.interpolation == "crop_resize": + pos_embed_crop = self.pos_embed[:, : self.init_length_for_resize, :H, :W] # B,T,H,W,C + _, t, h, w, c = pos_embed_crop.shape + + pos_embed_crop_resize_t = rearrange( + F.interpolate( + rearrange(pos_embed_crop, "1 t h w c -> 1 (c h w) t"), + size=(T), + mode="linear", + ), + "1 (c h w) t -> 1 t h w c", + c=c, + h=h, + w=w, + ) + pos_embed_crop_resize = rearrange( + F.interpolate( + rearrange(pos_embed_crop_resize_t, "1 t h w c -> 1 (c t) h w"), + size=(H, W), + mode="bilinear", + ), + "1 (c t) h w -> 1 t h w c", + c=c, + ) + return pos_embed_crop_resize + + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class SinCosPosEmb_FPS_Aware(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + min_fps: int, # 1 for getty video + max_fps: int, # 120 for getty video + is_learnable: 
bool = False, + interpolation: str = "crop", + spatial_interpolation_scale=1.0, + temporal_interpolation_scale=1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs # unused + super().__init__() + self.interpolation = interpolation + self.max_fps = max_fps + self.min_fps = min_fps + if self.interpolation == "crop": + param = get_3d_sincos_pos_embed( + model_channels, + len_h, + len_w, + len_t * int(max_fps / min_fps), + spatial_interpolation_scale, + temporal_interpolation_scale, + ) # should be max_seq_length * (max_fps / min_fps) + elif self.interpolation == "resize": + param = get_3d_sincos_pos_embed( + model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale + ) # time embedding based min fps + else: + ValueError(f"Unknown interpolation method {self.interpolation}") + param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) + if is_learnable: + self.pos_embed = nn.Parameter( + torch.from_numpy(param).float(), + ) + else: + self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + if self.interpolation == "crop": + if T > 1: + return torch.cat( + [ + self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / curr_fps), :H, :W] + for curr_fps in fps + ], + 0, + ) + else: + return self.pos_embed[:, :T, :H, :W] # image model + elif self.interpolation == "resize": + if T > 1: + return torch.cat( + [ + rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T * int(curr_fps / self.min_fps)), + mode="trilinear", + align_corners=True, # important: align corner need to be true + )[:, :, :H, :W, :T], + "1 c h w t -> 1 t h w c", + ) + for curr_fps in fps + ], + 0, + ) + else: + # grab self.pos_embed at time step 0 and resize spatially + return rearrange( + F.interpolate( + rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), + size=(H, W), + mode="bilinear", + align_corners=True, + ), + "1 c h w -> 1 h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class LearnableEmb3D(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + interpolation: str = "crop", + is_learnable: bool = True, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs # unused + super().__init__() + assert is_learnable is True + self.interpolation = interpolation + self.pos_embed = nn.Parameter(torch.zeros(1, len_t, len_h, len_w, model_channels)) + trunc_normal_(self.pos_embed, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + return self.pos_embed[:, :T, :H, :W] + if self.interpolation == "resize": + return rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T), + mode="linear", + align_corners=False, + ), + "1 c h w t -> 1 t h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class LearnableEmb3D_FPS_Aware(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + min_fps: int, # 1 for getty video + max_fps: int, # 120 for getty 
video + interpolation: str = "crop", + is_learnable: bool = True, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + assert is_learnable is True + self.interpolation = interpolation + self.max_fps = max_fps + self.min_fps = min_fps + + if self.interpolation == "crop": + self.pos_embed = nn.Parameter( + torch.zeros(1, len_t * int(max_fps / min_fps), len_h, len_w, model_channels) + ) # should be max_seq_length * (max_fps / min_fps) + elif self.interpolation == "resize": + self.pos_embed = nn.Parameter( + torch.zeros(1, len_t, len_h, len_w, model_channels) + ) # time embedding based min fps + else: + ValueError(f"Unknown interpolation method {self.interpolation}") + + trunc_normal_(self.pos_embed, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + if self.interpolation == "crop": + if T > 1: + return torch.cat( + [ + self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / curr_fps), :H, :W] + for curr_fps in fps + ], + 0, + ) + else: + return self.pos_embed[:, :T, :H, :W] # image model + elif self.interpolation == "resize": + if T > 1: + return torch.cat( + [ + rearrange( + F.interpolate( + rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), + size=(H, W, T * int(curr_fps / self.min_fps)), + mode="trilinear", + align_corners=True, # important: align corner need to be true + )[:, :, :H, :W, :T], + "1 c h w t -> 1 t h w c", + ) + for curr_fps in fps + ], + 0, + ) + else: + # grab self.pos_embed at time step 0 and resize spatially + return rearrange( + F.interpolate( + rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), + size=(H, W), + mode="bilinear", + align_corners=True, + ), + "1 c h w -> 1 h w c", + ) + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class VideoRopePositionEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(len_h * len_w * len_t, dtype=torch.float)) + + self.register_buffer( + "dim_range", torch.arange(0, head_dim, 2)[: (head_dim // 2)].float().cuda() / head_dim, persistent=False + ) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], ntk_factor: float = 1.0): + theta = 10000.0 * ntk_factor + + # original_dtype = self.dim_range.dtype + freq = 1.0 / (theta ** self.dim_range.float()) + _, T, H, W, _ = B_T_H_W_C + length = T * H * W + emb_L_D = torch.outer(self.seq[:length], freq) + return rearrange(torch.cat([emb_L_D, emb_L_D], dim=-1), "l d -> l 1 1 d").float() + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + self.max_t = len_t + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim 
== dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + self._dim_h = dim_h + self._dim_t = dim_t + + def reset_parameters(self) -> None: + if self.dim_spatial_range.device == torch.device("meta"): + return + + dim_h = self._dim_h + dim_t = self._dim_t + + self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device) + + self.dim_spatial_range = ( + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h + ) + self.dim_temporal_range = ( + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t + ) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration." + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." 
+ half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + + +class SinCosPosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + dim = model_channels + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + # rescale pos id is equivalent to rescale frequency + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = torch.cat( + [ + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W), + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H), + ], + dim=-1, + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + return emb + + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
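For reference, the factorized RoPE construction above can be reproduced in a few lines. The sketch below is only an illustration with assumed sizes (head_dim=128, T=4, H=3, W=5, a 12 fps clip against a 24 fps base): it shows the axis-wise head-dim split, the NTK-scaled frequencies, and the final (T*H*W, 1, 1, head_dim) table.

```python
import torch
from einops import rearrange, repeat

head_dim, T, H, W = 128, 4, 3, 5
base_fps, fps = 24, 12.0
t_extrapolation_ratio = 2.0  # assumed train-to-inference length ratio for the time axis

# Axis-wise split of the head dim: height and width share a size, the rest goes to time.
dim_h = head_dim // 6 * 2          # 42
dim_w = dim_h                      # 42
dim_t = head_dim - 2 * dim_h       # 44

# NTK-aware scaling: a ratio > 1 inflates the rotary base, stretching the longest wavelength.
t_ntk = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

def axis_freqs(dim: int, ntk: float = 1.0) -> torch.Tensor:
    theta = 10000.0 * ntk
    return 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))

h_freqs, w_freqs, t_freqs = axis_freqs(dim_h), axis_freqs(dim_w), axis_freqs(dim_t, t_ntk)

seq = torch.arange(max(T, H, W), dtype=torch.float)
half_emb_h = torch.outer(seq[:H], h_freqs)
half_emb_w = torch.outer(seq[:W], w_freqs)
# A 12 fps clip steps through temporal positions twice as fast as the 24 fps reference.
half_emb_t = torch.outer(seq[:T] / fps * base_fps, t_freqs)

em_T_H_W_D = torch.cat(
    [
        repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
        repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
        repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
    ]
    * 2,  # repeated once more, matching the rotate-half RoPE layout
    dim=-1,
)
print(rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").shape)  # torch.Size([60, 1, 1, 128])
```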
+ """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels)) + self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels)) + self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels)) + + trunc_normal_(self.pos_emb_h, std=0.02) + trunc_normal_(self.pos_emb_w, std=0.02) + trunc_normal_(self.pos_emb_t, std=0.02) + + def reset_parameters(self): + if self.pos_emb_h.device == torch.device("meta"): + return + + trunc_normal_(self.pos_emb_h, std=0.02) + trunc_normal_(self.pos_emb_w, std=0.02) + trunc_normal_(self.pos_emb_t, std=0.02) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:T] + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) + + +class MultiviewVideoPositionEmb(nn.Module): + def __init__( + self, + ): + super().__init__() + self.cp_group = None + + def enable_context_parallel(self, cp_group: ProcessGroup): + self.cp_group = cp_group + + def disable_context_parallel(self): + self.cp_group = None + + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: + """ + With CP, the function assume that the input tensor is already split. It delegates the embedding generation to generate_embeddings function. 
+ """ + B_T_H_W_C = x_B_T_H_W_C.shape + if self.cp_group is not None: + cp_ranks = get_process_group_ranks(self.cp_group) + cp_size = len(cp_ranks) + B, T, H, W, C = B_T_H_W_C + B_T_H_W_C = (B, T * cp_size, H, W, C) + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) + + if self.cp_group is not None: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + seq_dim = 1 + embeddings = rearrange(embeddings, "(V T) H W D -> V (T H W) 1 1 D", V=self.n_views).float() + # rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "V T 1 1 D -> (V T) 1 1 D", V=self.n_views).float() + else: + seq_dim = 1 + embeddings = rearrange(embeddings, "B (V T) H W C -> (B V) T H W C", V=self.n_views) + embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) + embeddings = rearrange(embeddings, "(B V) T H W C -> B (V T) H W C", V=self.n_views) + else: + if isinstance(self, MultiviewVideoRopePosition3DEmb): + embeddings = rearrange(embeddings, "t h w d -> (t h w) 1 1 d").float() + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): + raise NotImplementedError + + +class MultiviewVideoRopePosition3DEmb(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + self.n_views = n_views + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embedding_for_batch( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. 
+ """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or (fps.min() == fps.max()) + assert uniform_fps # only support uniform fps now + + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration." + half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + assert T == 1, "T should be 1 for image batch." + half_emb_t = torch.outer(self.seq[:T], temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d -> t h w d", h=H, w=W), + repeat(half_emb_h, "h d -> t h w d", t=T, w=W), + repeat(half_emb_w, "w d -> t h w d", t=T, h=H), + ] + * 2, + dim=-1, + ) + + return em_T_H_W_D + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. The camera view dimension is merged in the T dimension + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time * Views, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. + + Returns: + Not specified in the original code snippet. 
+ """ + + B, T, H, W, C = B_T_H_W_C + + single_view_B_T_H_W_C = (B, T // self.n_views, H, W, C) + em_T_H_W_D = torch.cat( + [ + self.generate_embedding_for_batch( + single_view_B_T_H_W_C, + fps=fps, + h_ntk_factor=h_ntk_factor, + w_ntk_factor=w_ntk_factor, + t_ntk_factor=t_ntk_factor, + ) + for item in range(self.n_views) + ], + dim=0, + ) + + return em_T_H_W_D + # return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() + + +class MultiviewSinCosPosEmbAxis(MultiviewVideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + n_views: int = 4, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + self.n_views = n_views + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + dim = model_channels + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + # rescale pos id is equivalent to rescale frequency + emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) + emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) + emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) + + self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) + self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) + self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: + B, T, H, W, C = B_T_H_W_C + + single_view_T = T // self.n_views + + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H] + emb_w_W = self.pos_emb_w[:W] + emb_t_T = self.pos_emb_t[:single_view_T] + emb = torch.cat( + [ + torch.cat( + [ + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), + repeat(emb_h_H, "h d-> b t h w d", b=B, t=single_view_T, w=W), + repeat(emb_w_W, "w d-> b t h w d", b=B, t=single_view_T, h=H), + ], + dim=-1, + ) + for _ in range(self.n_views) + ], + 1, + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + return emb + + raise ValueError(f"Unknown interpolation method {self.interpolation}") diff --git a/cosmos_predict1/diffusion/training/module/pretrained_vae.py b/cosmos_predict1/diffusion/training/module/pretrained_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a925b9ecaf7f6b952c1a3edd24236962ebbc74 --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/pretrained_vae.py @@ -0,0 +1,805 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch.nn.modules import Module + +from cosmos_predict1.diffusion.training.module.pretrained_vae_base import JITVAE, BaseVAE, StateDictVAE +from cosmos_predict1.utils import log + + +class VideoTokenizerInterface(ABC): + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + pass + + @abstractmethod + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + pass + + @abstractmethod + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + pass + + @property + @abstractmethod + def spatial_compression_factor(self): + pass + + @property + @abstractmethod + def temporal_compression_factor(self): + pass + + @property + @abstractmethod + def spatial_resolution(self): + pass + + @property + @abstractmethod + def pixel_chunk_duration(self): + pass + + @property + @abstractmethod + def latent_chunk_duration(self): + pass + + @property + def is_chunk_overlap(self): + return False + + +class BasePretrainedVideoTokenizer(ABC): + """ + Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. + + Args: + pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. + temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. + max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. + max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. + + The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) + which define how video data is subdivided and compressed during the encoding and decoding processes. The + `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory + constraints. 
+ """ + + def __init__( + self, + pixel_chunk_duration: int = 17, + temporal_compress_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + ): + self._pixel_chunk_duration = pixel_chunk_duration + self._temporal_compress_factor = temporal_compress_factor + self.max_enc_batch_size = max_enc_batch_size + self.max_dec_batch_size = max_dec_batch_size + + def register_mean_std(self, mean_std_fp: str) -> None: + latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) + latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] + + target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] + + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + """ + Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding + """ + B, C, T, H, W = state.shape + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) + + def transform_decode_state_shape(self, latent: torch.Tensor) -> None: + B, _, T, _, _ = latent.shape + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = state.shape + state = rearrange(state, "b c t h w -> (b t) c 1 h w") + B, C, T, H, W = state.shape + state = self.transform_encode_state_shape(state) + # use max_enc_batch_size to avoid OOM + if state.shape[0] > self.max_enc_batch_size: + latent = [] + for i in range(0, state.shape[0], self.max_enc_batch_size): + latent.append(super().encode(state[i : i + self.max_enc_batch_size])) + latent = torch.cat(latent, dim=0) + else: + latent = super().encode(state) + + latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) + return latent + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, + it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. + + It can also decode single frame image data. + + Args: + latent (torch.Tensor): The latent space tensor containing encoded video data. + + Returns: + torch.Tensor: The decoded video tensor reconstructed from latent space. 
+ """ + if self._temporal_compress_factor == 1: + _, _, origin_T, _, _ = latent.shape + latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") + B, _, T, _, _ = latent.shape + latent = self.transform_decode_state_shape(latent) + # use max_enc_batch_size to avoid OOM + if latent.shape[0] > self.max_dec_batch_size: + state = [] + for i in range(0, latent.shape[0], self.max_dec_batch_size): + state.append(super().decode(latent[i : i + self.max_dec_batch_size])) + state = torch.cat(state, dim=0) + else: + state = super().decode(latent) + assert state.shape[2] == self.pixel_chunk_duration + state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) + if self._temporal_compress_factor == 1: + return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) + return state + + @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + # return self._latent_chunk_duration + assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( + f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " + f"{self.latent_chunk_duration}" + ) + return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + @property + def temporal_compression_factor(self): + return self._temporal_compress_factor + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + +class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_bf16: bool = True, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__(pixel_chunk_duration, temporal_compression_factor, max_enc_batch_size, max_dec_batch_size) + super(BasePretrainedVideoTokenizer, self).__init__(enc_fp, dec_fp, name, mean_std_fp, latent_ch, False, is_bf16) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + +class VideoStateDictTokenizer(BasePretrainedVideoTokenizer, StateDictVAE, VideoTokenizerInterface): + """ + Instance of BasePretrainedVideoVAE that loads encoder and decoder from state_dict checkpoint + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + vae: torch.nn.Module, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_bf16: bool = True, + 
spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + pixel_chunk_duration: int = 17, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution: str = "720", + ): + super().__init__(pixel_chunk_duration, temporal_compression_factor, max_enc_batch_size, max_dec_batch_size) + super(BasePretrainedVideoTokenizer, self).__init__( + enc_fp, dec_fp, vae, name, mean_std_fp, latent_ch, is_image=False, is_bf16=is_bf16 + ) + + self._spatial_compression_factor = spatial_compression_factor + self._spatial_resolution = spatial_resolution + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + + +class VideoJITVAEChunkWiseTokenizer(VideoJITTokenizer): + """ + Do temporal chunk wise encoding and decoding. + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + spatial_compression_factor: int, + latent_ch: int = 16, + is_bf16: bool = True, + full_duration: int = 121, + chunk_duration: int = 49, + temporal_compression_factor: int = 8, + max_enc_batch_size: int = 8, + max_dec_batch_size: int = 4, + spatial_resolution="720", + overlap_size: int = 9, + ): + self._latent_chunk_duration = ( + chunk_duration - 1 + ) // temporal_compression_factor + 1 # need to set before super init + self._latent_full_duration = (full_duration - 1) // temporal_compression_factor + 1 + super().__init__( + enc_fp=enc_fp, + dec_fp=dec_fp, + name=name, + mean_std_fp=mean_std_fp, + latent_ch=latent_ch, + is_bf16=is_bf16, + pixel_chunk_duration=chunk_duration, + temporal_compression_factor=temporal_compression_factor, + max_enc_batch_size=max_enc_batch_size, + max_dec_batch_size=max_dec_batch_size, + spatial_resolution=spatial_resolution, + spatial_compression_factor=spatial_compression_factor, + ) + self.overlap_size = overlap_size + self.full_duration = full_duration + # make sure full_duration is divisible by chunk_duration with pre-set overlap size + assert (full_duration - overlap_size) % (chunk_duration - overlap_size) == 0 + + @property + def latent_chunk_duration(self) -> int: + return self._latent_chunk_duration + + @property + def latent_full_duration(self) -> int: + return self._latent_full_duration + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.full_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.full_duration}" + return num_pixel_frames // self.full_duration * self.latent_full_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_full_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_full_duration}" + return num_latent_frames // self.latent_full_duration * self.full_duration + + def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: + # This is a hack impl, should be improved later + return state + + def transform_decode_state_shape(self, latent: torch.Tensor) -> torch.Tensor: + # This is a hack impl, should be improved later + return latent + + def _impl_encode(self, state: torch.Tensor) -> torch.Tensor: + in_dtype = state.dtype + + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if 
isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + + assert state.shape[2] == self.full_duration + + # Calculate the number of overlapping windows/chunks + # Each window has a duration of self.pixel_chunk_duration frames + # The overlap between consecutive windows is self.overlap_size frames + num_windows = (T - self.pixel_chunk_duration) // (self.pixel_chunk_duration - self.overlap_size) + # Calculate the total number of frames covered by the windows + num_windowed_frames = self.pixel_chunk_duration + num_windows * (self.pixel_chunk_duration - self.overlap_size) + + assert num_windowed_frames == T # only handle case where number frames can be separated equally + # Prepare a list to hold overlapping chunks of the input state + pack_list = [state[:, :, : self.pixel_chunk_duration]] + [ + state[ + :, + :, + (ii + 1) + * (self.pixel_chunk_duration - self.overlap_size) : (ii + 1) + * (self.pixel_chunk_duration - self.overlap_size) + + self.pixel_chunk_duration, + ] + for ii in range(num_windows) + ] + + latent = self._impl_encode(torch.cat(pack_list, dim=0)) + latent = rearrange(latent, "(n b) c t h w -> n b c t h w", b=B) + # Calculate the overlap size in the latent space, accounting for any temporal compression + # For example, if the network downsamples temporally by a factor of 4, adjust the overlap accordingly + overlap_latent = (self.overlap_size - 1) // self.temporal_compression_factor + 1 + # Concatenate the latent representations from each chunk/window + # For the first chunk, include all latent frames + # For subsequent chunks, exclude the overlapping latent frames at the beginning + out = torch.cat([latent[0]] + [latent[i, :, :, overlap_latent:] for i in range(1, len(latent))], dim=2) + return out + + @torch.no_grad() + def maybe_pad_latent(self, latent: torch.Tensor) -> tuple[torch.Tensor, int]: + """Since the decoder expect the latent to be window_size + (window_size - decode_overlap_size) * N, we need to pad the latent to match the expected size + Args: + latent (torch.Tensor): [B, C, T, H, W] + Returns: + latent: torch.Tensor, the padded latent + padding_t: int, the number of padding latent t + """ + + # Calculate the overlap size and window size in the latent space, considering any temporal compression + decode_overlap_size = (self.overlap_size - 1) // self.temporal_compression_factor + 1 + # Calculate the number of windows/chunks for decoding + window_size = (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + B, C, current_latent_t, H, W = latent.shape + + if current_latent_t < window_size: + # If the current latent tensor is smaller than the window size, pad it to the window size + target_latent_t = window_size + else: + # Calculate the target latent frame number for decoding + target_latent_t = window_size + math.ceil( + (current_latent_t - window_size) / (window_size - decode_overlap_size) + ) * (window_size - decode_overlap_size) + + padding_t = target_latent_t - current_latent_t + if padding_t != 0: + log.info( + f"Padding latent from {current_latent_t} to {target_latent_t} for decoding purpose. 
+ current window_size: {window_size}, decode_overlap_size: {decode_overlap_size}" + ) + padding = latent.new_zeros(B, C, padding_t, H, W) + latent = torch.cat([latent, padding], dim=2).contiguous() + return latent, padding_t + + @torch.no_grad() + def decode(self, state: torch.Tensor) -> torch.Tensor: + state, padding_t = self.maybe_pad_latent(state) + B, C, num_latents, H, W = state.shape + + # Calculate the overlap size and window size in the latent space, considering any temporal compression + decode_overlap_size = (self.overlap_size - 1) // self.temporal_compression_factor + 1 + # Calculate the number of windows/chunks for decoding + window_size = (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 + + num_windows = (num_latents - window_size) // (window_size - decode_overlap_size) + 1 + decoded_frames = [] + # Start decoding with the initial window of latent frames + current_state = state[:, :, :window_size] + for i in range(num_windows): + # Decode the current window to get the reconstructed frames + window_frames = super().decode(current_state) + decoded_frames.append(window_frames) + # Re-encode the overlapping frames at the end of the decoded window to obtain the last latent frame + # This is necessary due to the causal first-frame design + last_latent = self._impl_encode(window_frames[:, :, -self.overlap_size : -self.overlap_size + 1])[:, :, 0:1] + # Calculate the start and end indices for the next chunk of latent frames + start_idx = window_size + i * (window_size - decode_overlap_size) - decode_overlap_size + 1 + end_idx = start_idx + window_size - 1 + # Prepare the next state by concatenating the last latent frame with the next chunk of latent frames + current_state = torch.cat([last_latent, state[:, :, start_idx:end_idx]], dim=2) + # Remove the first overlap_size frames (the overlap) from every window except the first one.
+ for i in range(1, num_windows): + decoded_frames[i] = decoded_frames[i][:, :, self.overlap_size :] + video_tensor = torch.cat(decoded_frames, dim=2) + return video_tensor + + @property + def is_chunk_overlap(self): + return True + + +class DebugMeanStdVideoJITVAE(VideoJITTokenizer): + """ + A class for one + """ + + def register_mean_std(self, mean_std_fp: str) -> None: + target_shape = [1, self.latent_ch, 1, 1, 1] + self.register_buffer( + "latent_mean", + # latent_mean.to(self.dtype).reshape(*target_shape), + torch.zeros(*target_shape, dtype=self.dtype), + persistent=False, + ) + self.register_buffer( + "latent_std", + torch.ones(*target_shape, dtype=self.dtype), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return JITVAE.encode(self, state) + return super().encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, _, T, _, _ = latent.shape + if T == 1: + return JITVAE.decode(self, latent) + return super().decode(latent) + + +class DebugMeanStdVideoJITVAEChunkWiseTokenizer(VideoJITVAEChunkWiseTokenizer): + def register_mean_std(self, mean_std_fp: str) -> None: + target_shape = [1, self.latent_ch, 1, 1, 1] + self.register_buffer( + "latent_mean", + # latent_mean.to(self.dtype).reshape(*target_shape), + torch.zeros(*target_shape, dtype=self.dtype), + persistent=False, + ) + self.register_buffer( + "latent_std", + torch.ones(*target_shape, dtype=self.dtype), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return JITVAE.encode(self, state) + return super().encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, _, T, _, _ = latent.shape + if T == 1: + return JITVAE.decode(self, latent) + return super().decode(latent) + + +class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( + self, + image_vae: torch.nn.Module, + video_vae: torch.nn.Module, + name: str, + latent_ch: int = 16, + squeeze_for_image: bool = True, + ): + super().__init__(latent_ch, name) + self.image_vae = image_vae + self.video_vae = video_vae + self.squeeze_for_image = squeeze_for_image + + def encode_image(self, state: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) + return self.image_vae.encode(state) + + def decode_image(self, latent: torch.Tensor) -> torch.Tensor: + if self.squeeze_for_image: + return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) + return self.image_vae.decode(latent) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + return self.encode_image(state) + + return self.video_vae.encode(state) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + return self.decode_image(latent) + return self.video_vae.decode(latent) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. 
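The window and overlap bookkeeping in `VideoJITVAEChunkWiseTokenizer` above is easiest to check with concrete numbers. The values below (89 pixel frames, 49-frame chunks, 9-frame overlap, 8x temporal compression) are chosen only so that the divisibility assert holds.

```python
chunk_duration = 49      # pixel frames per window
overlap_size = 9         # pixel frames shared between consecutive windows
full_duration = 89       # (full - overlap) must be divisible by (chunk - overlap): 80 % 40 == 0
tcf = 8                  # temporal compression factor

assert (full_duration - overlap_size) % (chunk_duration - overlap_size) == 0

# Encoding: how many extra windows follow the first one, and how many frames they cover.
num_windows = (full_duration - chunk_duration) // (chunk_duration - overlap_size)   # 1
assert chunk_duration + num_windows * (chunk_duration - overlap_size) == full_duration

# Latent-space sizes (the "+1" accounts for the causal first frame of each chunk).
latent_chunk = (chunk_duration - 1) // tcf + 1          # 7 latent frames per window
overlap_latent = (overlap_size - 1) // tcf + 1          # 2 overlapping latent frames
latent_full = (full_duration - 1) // tcf + 1            # 12 latent frames in total

# Concatenating window latents while dropping the leading overlap of later windows
# reproduces the full latent length: 7 + 1 * (7 - 2) == 12.
assert latent_chunk + num_windows * (latent_chunk - overlap_latent) == latent_full
print(num_windows, latent_chunk, overlap_latent, latent_full)
```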
+ """ + del args, kwargs + self.image_vae.reset_dtype() + self.video_vae.reset_dtype() + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + return self.video_vae.get_latent_num_frames(num_pixel_frames) + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + return self.video_vae.get_pixel_num_frames(num_latent_frames) + + @property + def spatial_compression_factor(self): + return self.video_vae.spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self.video_vae.temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self.video_vae.spatial_resolution + + @property + def pixel_chunk_duration(self) -> int: + return self.video_vae.pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self.video_vae.latent_chunk_duration + + +class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): + """ + First version of the ImageVideoVAE trained with Fitsum. + We have to use seperate mean and std for image and video due to non-causal nature of the model. + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + assert isinstance(image_vae, JITVAE) + assert isinstance( + video_vae, VideoJITTokenizer + ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" + # a hack to make the image_vae and video_vae share the same encoder and decoder + self.image_vae.encoder = self.video_vae.encoder + self.image_vae.decoder = self.video_vae.decoder + + +class JointImageVideoStateDictTokenizer(JointImageVideoTokenizer): + """ + Copy of ImageVideoVAE1 that uses plain torch.nn.Module instead of JITed one so + that it can be used witch torch.compile() + """ + + def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): + super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) + + assert isinstance(image_vae, StateDictVAE) + assert isinstance(video_vae, VideoStateDictTokenizer) + # a hack to make the image_vae and video_vae share the same encoder and decoder + + # nn.Module + del self.image_vae.vae + # Just method + del self.image_vae.encoder + # Just method + del self.image_vae.decoder + + self.image_vae.vae = self.video_vae.vae + self.image_vae.encoder = self.video_vae.encoder + self.image_vae.decoder = self.video_vae.decoder + + +class DummyJointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): + def __init__( + self, + name: str = "dummy_joint_image_video", + pixel_ch: int = 3, + latent_ch: int = 16, + pixel_chunk_duration: int = 17, + latent_chunk_duration: int = 3, + spatial_compression_factor: int = 16, + temporal_compression_factor: int = 8, + spatial_resolution: str = "720", + ): + self.pixel_ch = pixel_ch + self._pixel_chunk_duration = pixel_chunk_duration + self._latent_chunk_duration = latent_chunk_duration + self._spatial_compression_factor = spatial_compression_factor + self._temporal_compression_factor = temporal_compression_factor + self._spatial_resolution = spatial_resolution + super().__init__(latent_ch, name) + + @property + def spatial_compression_factor(self): + return self._spatial_compression_factor + + @property + def temporal_compression_factor(self): + return self._temporal_compression_factor + + @property + def spatial_resolution(self) -> str: + return self._spatial_resolution + 
+ @property + def pixel_chunk_duration(self) -> int: + return self._pixel_chunk_duration + + @property + def latent_chunk_duration(self) -> int: + return self._latent_chunk_duration + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + if num_pixel_frames == 1: + return 1 + assert ( + num_pixel_frames % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" + return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration + + def get_pixel_num_frames(self, num_latent_frames: int) -> int: + if num_latent_frames == 1: + return 1 + assert ( + num_latent_frames % self.latent_chunk_duration == 0 + ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" + return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration + + def encode(self, state: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = state.shape + if T == 1: + state_B_T_C_H_W = F.interpolate( + rearrange(state, "b c t h w -> b t c h w"), + size=(self.latent_ch, H // self.spatial_compression_factor, W // self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + return rearrange(state_B_T_C_H_W, "b t c h w -> b c t h w").contiguous() + assert ( + T % self.pixel_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" + num_frames = T // self.pixel_chunk_duration * self.latent_chunk_duration + + state_B_C_T_H_W = F.interpolate( + state, + size=(self.latent_ch, H // self.spatial_compression_factor, W // self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + state_B_H_W_T_C = rearrange(state_B_C_T_H_W, "b c t h w -> b h w t c") + state_B_H_W_T_C = F.interpolate( + state_B_H_W_T_C, + size=(W // self.spatial_compression_factor, num_frames, self.latent_ch), + mode="trilinear", + align_corners=False, + ) + return rearrange(state_B_H_W_T_C, "b h w t c -> b c t h w").contiguous() + + def decode(self, latent: torch.Tensor) -> torch.Tensor: + B, C, T, H, W = latent.shape + if T == 1: + latent_B_T_C_H_W = F.interpolate( + rearrange(latent, "b c t h w -> b t c h w"), + size=(self.pixel_ch, H * self.spatial_compression_factor, W * self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + return rearrange(latent_B_T_C_H_W, "b t c h w -> b c t h w").contiguous() + + assert ( + T % self.latent_chunk_duration == 0 + ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" + num_frames = T * self.pixel_chunk_duration // self.latent_chunk_duration + + latent_B_H_W_T_C = rearrange(latent, "b c t h w -> b h w t c") + latent_B_H_W_T_C = F.interpolate( + latent_B_H_W_T_C, + size=(W * self.spatial_compression_factor, num_frames, self.pixel_ch), + mode="trilinear", + align_corners=False, + ) + latent_B_C_T_H_W = rearrange(latent_B_H_W_T_C, "b h w t c -> b c t h w") + + state = F.interpolate( + latent_B_C_T_H_W, + size=(num_frames, H * self.spatial_compression_factor, W * self.spatial_compression_factor), + mode="trilinear", + align_corners=False, + ) + + return state.contiguous() diff --git a/cosmos_predict1/diffusion/training/module/pretrained_vae_base.py b/cosmos_predict1/diffusion/training/module/pretrained_vae_base.py new file mode 100644 index 0000000000000000000000000000000000000000..08570cbdcc9b8d9875ee907a28001aa5193d8fc5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/module/pretrained_vae_base.py @@ 
-0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod + +import torch +import torch.nn.functional as F + +from cosmos_predict1.utils.distributed import rank0_first +from cosmos_predict1.utils.misc import load_from_s3_with_cache + + +class BaseVAE(torch.nn.Module, ABC): + """ + Abstract base class for a Variational Autoencoder (VAE). + + All subclasses should implement the methods to define the behavior for encoding + and decoding, along with specifying the latent channel size. + """ + + def __init__(self, channel: int = 3, name: str = "vae"): + super().__init__() + self.channel = channel + self.name = name + + @property + def latent_ch(self) -> int: + """ + Returns the number of latent channels in the VAE. + """ + return self.channel + + @abstractmethod + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encodes the input tensor into a latent representation. + + Args: + - state (torch.Tensor): The input tensor to encode. + + Returns: + - torch.Tensor: The encoded latent tensor. + """ + pass + + @abstractmethod + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decodes the latent representation back to the original space. + + Args: + - latent (torch.Tensor): The latent tensor to decode. + + Returns: + - torch.Tensor: The decoded tensor. + """ + pass + + @property + def spatial_compression_factor(self) -> int: + """ + Returns the spatial reduction factor for the VAE. + """ + raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") + + +class BasePretrainedImageVAE(BaseVAE): + """ + A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values + from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + Derived classes should load pre-trained encoder and decoder components from a remote store + + Attributes: + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
+ """ + + def __init__( + self, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ) -> None: + super().__init__(latent_ch, name) + dtype = torch.bfloat16 if is_bf16 else torch.float32 + self.dtype = dtype + self.is_image = is_image + self.mean_std_fp = mean_std_fp + self.name = name + + self.backend_args = None + + self.register_mean_std(mean_std_fp) + + def register_mean_std(self, mean_std_fp: str) -> None: + latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) + target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] + self.register_buffer( + "latent_mean", + latent_mean.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + self.register_buffer( + "latent_std", + latent_std.to(self.dtype).reshape(*target_shape), + persistent=False, + ) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + Encode the input state to latent space; also handle the dtype conversion, mean and std scaling + """ + in_dtype = state.dtype + latent_mean = self.latent_mean.to(in_dtype) + latent_std = self.latent_std.to(in_dtype) + encoded_state = self.encoder(state.to(self.dtype)) + if isinstance(encoded_state, torch.Tensor): + pass + elif isinstance(encoded_state, tuple): + assert isinstance(encoded_state[0], torch.Tensor) + encoded_state = encoded_state[0] + else: + raise ValueError("Invalid type of encoded state") + return (encoded_state.to(in_dtype) - latent_mean) / latent_std + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """ + Decode the input latent to state; also handle the dtype conversion, mean and std scaling + """ + in_dtype = latent.dtype + latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) + return self.decoder(latent.to(self.dtype)).to(in_dtype) + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.decoder.to(self.dtype) + self.encoder.to(self.dtype) + + +class JITVAE(BasePretrainedImageVAE): + """ + A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder + and decoder components from a remote store, handles data type conversions, and normalization + using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The JIT compiled encoder loaded from storage. + decoder (Module): The JIT compiled decoder loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + name (str): Name of the model, used for differentiating cache file paths. + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
+ """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) + self.load_encoder(enc_fp) + self.load_decoder(dec_fp) + + def load_encoder(self, enc_fp: str) -> None: + """ + Load the encoder from the remote store. + + Args: + - enc_fp (str): File path to the encoder's JIT file on the remote store. + """ + self.encoder = torch.jit.load(enc_fp, map_location="cuda") + self.encoder.eval() + for param in self.encoder.parameters(): + param.requires_grad = False + self.encoder.to(self.dtype) + + def load_decoder(self, dec_fp: str) -> None: + """ + Load the decoder from the remote store. + + Args: + - dec_fp (str): File path to the decoder's JIT file on the remote store. + """ + self.decoder = torch.jit.load(dec_fp, map_location="cuda") + self.decoder.eval() + for param in self.decoder.parameters(): + param.requires_grad = False + self.decoder.to(self.dtype) + + +class StateDictVAE(BasePretrainedImageVAE): + """ + A Variational Autoencoder (VAE) that loads pre-trained weights into + provided encoder and decoder components from a remote store, handles data type conversions, + and normalization using provided mean and standard deviation values for latent space representation. + + Attributes: + encoder (Module): The encoder with weights loaded from storage. + decoder (Module): The decoder with weights loaded from storage. + latent_mean (Tensor): The mean used for normalizing the latent representation. + latent_std (Tensor): The standard deviation used for normalizing the latent representation. + dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. + + Args: + enc_fp (str): File path to the encoder's JIT file on the remote store. + dec_fp (str): File path to the decoder's JIT file on the remote store. + vae (Module): Instance of VAE with not loaded weights + name (str): Name of the model, used for differentiating cache file paths. + mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. + latent_ch (int, optional): Number of latent channels (default is 16). + is_image (bool, optional): Flag to indicate whether the output is an image (default is True). + is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). + """ + + def __init__( + self, + enc_fp: str, + dec_fp: str, + vae: torch.nn.Module, + name: str, + mean_std_fp: str, + latent_ch: int = 16, + is_image: bool = True, + is_bf16: bool = True, + ): + super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) + + self.load_encoder_and_decoder(enc_fp, dec_fp, vae) + + def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None: + """ + Load the encoder from the remote store. + + Args: + - vae_fp (str): File path to the vae's state dict file on the remote store. + - vae (str): VAE module into which weights will be loaded. 
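+
+        Note: despite the parameter names documented above, this method reads JIT checkpoints from
+        both `enc_fp` and `dec_fp`, merges their state dicts (dropping JIT-captured buffers such as
+        the wavelet constants), and loads the result into the provided `vae` module.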
+ """ + state_dict_enc = load_from_s3_with_cache( + enc_fp, + f"vae/{self.name}_enc.jit", + easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, + backend_args=self.backend_args, + ) + + state_dict_dec = load_from_s3_with_cache( + dec_fp, + f"vae/{self.name}_dec.jit", + easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, + backend_args=self.backend_args, + ) + + jit_weights_state_dict = state_dict_enc.state_dict() | state_dict_dec.state_dict() + jit_weights_state_dict = { + k: v + for k, v in jit_weights_state_dict.items() + # Global variables captured by JIT + if k + not in ( + "encoder.patcher.wavelets", + "encoder.patcher._arange", + "decoder.unpatcher.wavelets", + "decoder.unpatcher._arange", + ) + } + + vae.load_state_dict(jit_weights_state_dict) + vae.eval() + for param in vae.parameters(): + param.requires_grad = False + vae.to(self.dtype) + + self.vae = vae + self.encoder = self.vae.encode + self.decoder = self.vae.decode + + def reset_dtype(self, *args, **kwargs): + """ + Resets the data type of the encoder and decoder to the model's default data type. + + Args: + *args, **kwargs: Unused, present to allow flexibility in method calls. + """ + del args, kwargs + self.vae.to(self.dtype) + + +class SDVAE(BaseVAE): + def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None: + super().__init__(channel=4, name="sd_vae") + self.dtype = torch.bfloat16 + self.register_buffer( + "scale", + torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1), + persistent=False, + ) + self.register_buffer( + "bias", + -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale, + persistent=False, + ) + self.batch_size = batch_size + self.count_std = count_std + self.is_downsample = is_downsample + self.load_vae() + self.reset_dtype() + + def reset_dtype(self, *args, **kwargs): + del args, kwargs + self.vae.to(self.dtype) + + @rank0_first + def load_vae(self) -> None: + os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" + os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" + import diffusers + + vae_name = "stabilityai/sd-vae-ft-mse" + try: + vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True) + except: # noqa: E722 + # Could not load the model from cache; try without local_files_only. 
+ vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name) + self.vae = vae.eval().requires_grad_(False) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """ + state : pixel range [-1, 1] + """ + if self.is_downsample: + _h, _w = state.shape[-2:] + state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False) + in_dtype = state.dtype + state = state.to(self.dtype) + state = (state + 1.0) / 2.0 + latent_dist = self.vae.encode(state)["latent_dist"] + mean, std = latent_dist.mean, latent_dist.std + if self.count_std: + latent = mean + torch.randn_like(mean) * std + else: + latent = mean + latent = latent * self.scale + latent = latent + self.bias + return latent.to(in_dtype) + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + in_dtype = latent.dtype + latent = latent.to(self.dtype) + latent = latent - self.bias + latent = latent / self.scale + latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)]) + if self.is_downsample: + _h, _w = latent.shape[-2:] + latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False) + return latent.to(in_dtype) * 2 - 1.0 + + @property + def spatial_compression_factor(self) -> int: + return 8 diff --git a/cosmos_predict1/diffusion/training/modules/edm_sde.py b/cosmos_predict1/diffusion/training/modules/edm_sde.py new file mode 100644 index 0000000000000000000000000000000000000000..3d08a8229f03c9fdd6a8d905ad4543fe5fe5238a --- /dev/null +++ b/cosmos_predict1/diffusion/training/modules/edm_sde.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
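+
+# Note on EDMSDE.sample_t (defined below): noise levels sigma are drawn from a log-normal
+# distribution, log(sigma) ~ N(p_mean, p_std), using the EDM defaults p_mean=-1.2, p_std=1.2
+# (Karras et al., 2022). Sampling uniform CDF values and mapping them through
+# NormalDist.inv_cdf is inverse-transform sampling, equivalent to the torch-only sketch
+# below (illustrative only, not part of this module's API):
+#
+#     log_sigma = p_mean + p_std * torch.randn(batch_size, device="cuda")
+#     sigma = log_sigma.exp()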
+ +from statistics import NormalDist + +import numpy as np +import torch + + +class EDMSDE: + def __init__( + self, + p_mean: float = -1.2, + p_std: float = 1.2, + sigma_max: float = 80.0, + sigma_min: float = 0.002, + ): + self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) + self.sigma_max = sigma_max + self.sigma_min = sigma_min + + def sample_t(self, batch_size: int) -> torch.Tensor: + cdf_vals = np.random.uniform(size=(batch_size)) + samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] + + log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") + return torch.exp(log_sigma) + + def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """This is trivial in the base class, but may be used by derived classes in a more interesting way""" + return x0, sigma diff --git a/cosmos_predict1/diffusion/training/networks/general_dit.py b/cosmos_predict1/diffusion/training/networks/general_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6fb610208ea6c108bd951818ea1a74b760715e --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit.py @@ -0,0 +1,1022 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +It allows us easy to switch building blocks used and their order. Its instantiation includes +* transformer on fully flattened tokens +* factored spatial and temporal attention +* factored non-overlap spatial and temporal attention +* mixing of above attention types + +Limitations: + +* In favor of simplicity and cleanness, many ops are not fused and we can do better +* such as combining mutiple adaln MLPs into one inside one transformer block. 
+* we use reshape heavily, which may be not efficient when its occurs unnecessary CUDA memory copy + +Purpose: +* A prototype for testing different attention types and their combinations +* Idealy, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies +""" + +from collections.abc import Container +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import nn +from torch.distributed import ProcessGroup, get_process_group_ranks +from torchvision import transforms + +from cosmos_predict1.diffusion.module.attention import get_normalization +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.module.blocks import ( + DITBuildingBlock, + FinalLayer, + GeneralDITTransformerBlock, + PatchEmbed, + SDXLTimestepEmbedding, + SDXLTimesteps, +) +from cosmos_predict1.diffusion.training.module.position_embedding import ( + LearnableEmb3D, + LearnableEmb3D_FPS_Aware, + LearnablePosEmbAxis, + SinCosPosEmb, + SinCosPosEmb_FPS_Aware, + SinCosPosEmbAxis, + VideoRopePosition3DEmb, + VideoRopePositionEmb, +) +from cosmos_predict1.diffusion.training.tensor_parallel import gather_along_first_dim, scatter_along_first_dim +from cosmos_predict1.utils import log + + +class GeneralDIT(nn.Module): + """ + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + Attributes: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple of int): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + block_config (str): Configuration of the transformer block, e.g., 'FA-CA-MLP', means + full attention, cross attention, and MLP in sequence in one transformer block. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of residual blocks per resolution in the transformer. + num_heads (int): Number of heads in the multi-head self-attention layers. + spatial_attn_win_size (int): Window size for the spatial attention mechanism. + temporal_attn_win_size (int): Window size for the temporal attention mechanism. + mlp_ratio (float): Expansion ratio for the MLP (multi-layer perceptron) blocks in the transformer. + use_memory_save (bool): If True, utilizes checkpointing to reduce memory usage during training. (Deprecated) + use_checkpoint (bool): If True, utilizes checkpointing to reduce memory usage during training for all blocks. + crossattn_emb_channels (int): Number of embedding channels used in the cross-attention layers. + use_cross_attn_mask (bool): If True, applies a mask during cross-attention operations to manage sequence alignment. + pos_emb_cls (str): Type of positional embeddings used ('sincos' for sinusoidal or other types). + pos_emb_learnable (bool): Specifies if positional embeddings are learnable. + pos_emb_interpolation (str): Method used for interpolating positional embeddings, e.g., 'crop' for cropping adjustments. 
+ block_x_format (str, optional): The format of the input tensor for the transformer block. Defaults to "BTHWD". Only support 'BTHWD' and 'THWBD'. + legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! This is for backward compatibility. + rope_h_extrapolation_ratio (float): Ratio of the height extrapolation for the rope positional embedding. + rope_w_extrapolation_ratio (float): Ratio of the width extrapolation for the rope positional embedding. + rope_t_extrapolation_ratio (float): Ratio of the temporal extrapolation for the rope positional embedding. + Note: + block_config support block type: + * spatial_sa, ssa: spatial self attention + * temporal_sa, tsa: temporal self attention + * cross_attn, ca: cross attention + * full_attn: full attention on all flatten tokens + * mlp, ff: feed forward block + * use '-' to separate different building blocks, e.g., 'FA-CA-MLP' means full attention, cross attention, and MLP in sequence in one transformer block. + + Example: + >>> # full attention, cross attention, and MLP + >>> option1_block_config = 'FA-CA-MLP' + >>> model_1 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option1_block_config + ) + >>> option2_block_config = 'SSA-CA-MLP-TSA-CA-MLP' + >>> model_2 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=1, + block_config=option2_block_config + ) + >>> # option3 model + >>> model_3 = GeneralDIT( + max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, + patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, + num_heads=16, mlp_ratio=4.0, + spatial_attn_win_size=1, temporal_attn_win_size=2, + block_config=option2_block_config + ) + >>> # Process input tensor through the model + >>> output = model(input_tensor) + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + window_block_indexes: list = [], # index for window attention block + window_sizes: list = [], # window size for window attention block in the order of T, H, W + spatial_attn_win_size: int = 1, + temporal_attn_win_size: int = 1, + mlp_ratio: float = 4.0, + use_memory_save: bool = False, + use_checkpoint: bool = False, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, # 1 for getty video + max_fps: int = 30, # 120 for getty video but let's use 30 + additional_timestamp_channels: dict = None, # Follow SDXL, in format of {condition_name : dimension} + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + layer_mask: list = None, # whether or not a layer is used. 
For controlnet encoder + legacy_patch_emb: bool = True, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = True, + extra_per_block_abs_pos_emb_type: str = "learnable", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.use_cross_attn_mask = use_cross_attn_mask + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.additional_timestamp_channels = additional_timestamp_channels + self.affline_emb_norm = affline_emb_norm + self.legacy_patch_emb = legacy_patch_emb + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + + self.build_patch_embed() + self.build_pos_embed() + self.cp_group = None + self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + self.block_x_format = block_x_format + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + SDXLTimesteps(model_channels), + SDXLTimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + ) + + self.blocks = nn.ModuleDict() + self.block_config = block_config + self.use_memory_save = use_memory_save + self.use_checkpoint = use_checkpoint + + assert ( + len(window_block_indexes) == 0 or block_config == "FA-CA-MLP" + ), "Block config must be FA-CA-MLP if using a combination of window attention and global attention" + + layer_mask = [False] * num_blocks if layer_mask is None else layer_mask + assert ( + len(layer_mask) == num_blocks + ), f"Layer mask length {len(layer_mask)} does not match num_blocks {num_blocks}" + for idx in range(num_blocks): + if layer_mask[idx]: + continue + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + window_sizes=( + window_sizes if idx in window_block_indexes else [] + ), # There will be bug if using "WA-CA-MLP" + mlp_ratio=mlp_ratio, + spatial_attn_win_size=spatial_attn_win_size, + temporal_attn_win_size=temporal_attn_win_size, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + use_checkpoint=use_checkpoint, + ) + + self.build_decode_head() + self.build_additional_timestamp_embedder() + if self.affline_emb_norm: + log.critical("Building affine embedding normalization layer") + self.affline_norm = 
get_normalization("R", model_channels) + else: + self.affline_norm = nn.Identity() + self.init_weights() + + if self.use_memory_save: + log.critical("Using checkpointing to save memory! only verified in 14B base model training!") + for block in self.blocks.values(): + block.set_memory_save() + + def init_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding + nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) + if self.t_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) + if self.t_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + for transformer_block in self.blocks.values(): + for block in transformer_block.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + if block.adaLN_modulation[-1].bias is not None: + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Tensor parallel + if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: + self.initialize_tensor_parallel_weights() + + def initialize_tensor_parallel_weights(self): + """ + Initialize weights for tensor parallel layers. + + This function performs the following steps: + 1. Retrieves the tensor parallel rank. + 2. Saves the current random state. + 3. Sets a new random seed based on the tensor parallel rank. + 4. Initializes weights for attention and MLP layers in each block. + 5. Restores the original random state. + + The use of different random seeds for each rank ensures + unique initializations across parallel processes. 
+ """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + # Save the current random state + rng_state = torch.get_rng_state() + + # Set a new random seed based on the tensor parallel rank + torch.manual_seed(tp_rank) + + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: + # Initialize weights for attention layers + torch.nn.init.xavier_uniform_(layer.block.attn.to_q[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_k[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_v[0].weight) + torch.nn.init.xavier_uniform_(layer.block.attn.to_out[0].weight) + elif layer.block_type in ["mlp", "ff"]: + # Initialize weights for MLP layers + torch.nn.init.xavier_uniform_(layer.block.layer1.weight) + torch.nn.init.xavier_uniform_(layer.block.layer2.weight) + else: + raise ValueError(f"Unknown block type {layer.block_type}") + + # Restore the original random state + torch.set_rng_state(rng_state) + + def build_decode_head(self): + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + ) + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + ) + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, + ) + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + if self.legacy_patch_emb: + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def build_additional_timestamp_embedder(self): + if self.additional_timestamp_channels: + self.additional_timestamp_embedder = nn.ModuleDict() + for cond_name, cond_emb_channels in self.additional_timestamp_channels.items(): + log.critical( + f"Building additional timestamp embedder for {cond_name} with {cond_emb_channels} channels" + ) + self.additional_timestamp_embedder[cond_name] = nn.Sequential( + SDXLTimesteps(cond_emb_channels), + SDXLTimestepEmbedding(cond_emb_channels, cond_emb_channels), + ) + + def prepare_additional_timestamp_embedder(self, **kwargs): + condition_concat = [] + + for cond_name, embedder in self.additional_timestamp_embedder.items(): + condition_concat.append(embedder(kwargs[cond_name])[0]) + embedding = torch.cat(condition_concat, dim=1) + if embedding.shape[1] < self.model_channels: + embedding = nn.functional.pad(embedding, (0, self.model_channels - embedding.shape[1])) + return embedding + + def build_pos_embed(self): + if self.pos_emb_cls == "sincos": + cls_type = SinCosPosEmb + elif self.pos_emb_cls == "learnable": + cls_type = LearnableEmb3D + elif self.pos_emb_cls == "sincos_fps_aware": + cls_type = SinCosPosEmb_FPS_Aware + elif self.pos_emb_cls == "learnable_fps_aware": + cls_type = LearnableEmb3D_FPS_Aware + elif self.pos_emb_cls == "rope": + cls_type = VideoRopePositionEmb + elif self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + 
log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + assert self.extra_per_block_abs_pos_emb is True, "extra_per_block_abs_pos_emb must be True" + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "learnable", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
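+            - In addition to the two values documented under Returns, the method also returns an
+              optional extra per-block absolute positional embedding (from `self.extra_pos_embedder`)
+              as a third element; it is None when `self.extra_per_block_abs_pos_emb` is False.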
+ """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb + + def decoder_head( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del crossattn_emb, crossattn_mask + B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape + x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") + x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). + x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + return x_B_D_T_H_W + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
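+        # Patchify the input and apply positional embeddings, then build the per-sample affine (adaLN)
+        # embedding from the timestep (plus optional additional timestamp conditions), and finally pack
+        # everything the transformer blocks need (x, affine emb, cross-attn emb/mask, rope emb,
+        # adaLN-LoRA emb, original shape) into a dict.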
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward_blocks_regular( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + features = [] + for name, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + # Extract features + block_idx = int(name.split("block")[-1]) + if block_idx in feature_indices: + B, C, T, H, W = original_shape + H = H // self.patch_spatial + W = W // self.patch_spatial + T = T // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x_feat, "T H W B D -> B T H W D", T=T, H=H, W=W) + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x_feat + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + features.append(x_B_T_H_W_D) + + if x_ctrl is not None and name in x_ctrl: + x = x + x_ctrl[name] + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + if self.blocks["block0"].x_format == "THWBD": + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + elif self.blocks["block0"].x_format == "BTHWD": + x_B_T_H_W_D = x + else: + raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + + def forward_blocks_memory_save( + self, + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ): + x_before_gate = 0 
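+        # Memory-save path: each block consumes and produces the triple (gate, pre-gate output, skip path)
+        # rather than a fully combined activation; the final gated combination is applied in
+        # FinalLayer.forward_with_memory_save. The initializers here (zero pre-gate output, gate 1.0,
+        # and the flattened input as the skip path) stand in for the output of a block that does not exist.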
+ x_skip = rearrange(x, "T H W B D -> (T H W) B D") + assert self.blocks["block0"].x_format == "THWBD" + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_per_block_pos_emb = rearrange(extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "T H W B D -> (T H W) B D") + else: + extra_per_block_pos_emb = None + gate_L_B_D = 1.0 + + features = [] + for name, block in self.blocks.items(): + gate_L_B_D, x_before_gate, x_skip = block( + x_before_gate, + x_skip, + gate_L_B_D, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_per_block_pos_emb, + ) + + # Extract features. + # Convert the block index in the memory save mode to the block index in the regular mode. + block_idx = int(name.split("block")[-1]) - 1 + if block_idx in feature_indices: + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + H = H_before_patchify // self.patch_spatial + W = W_before_patchify // self.patch_spatial + T = T_before_patchify // self.patch_temporal + if self.sequence_parallel: + x_feat = gather_along_first_dim(x_skip, parallel_state.get_tensor_model_parallel_group()) + x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) + else: + x_feat = x_skip + x_B_T_H_W_D = rearrange(x_feat, "(T H W) B D -> B T H W D", T=T, H=H, W=W) + + features.append(x_B_T_H_W_D) + + new_name = f"block{block_idx}" + if x_ctrl is not None and new_name in x_ctrl: + x_ctrl_ = x_ctrl[new_name] + x_ctrl_ = rearrange(x_ctrl_, "T H W B D -> (T H W) B D") + x_skip = x_skip + x_ctrl_ + # If we have all of the features, we can exit early + if return_features_early and len(features) == len(feature_indices): + return features + + x_THW_B_D_before_gate = x_before_gate + x_THW_B_D_skip = x_skip + + B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape + x_BT_HW_D_before_gate = rearrange( + x_THW_B_D_before_gate, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + x_BT_HW_D_skip = rearrange( + x_THW_B_D_skip, + "(T H W) B D -> (B T) (H W) D", + T=T_before_patchify // self.patch_temporal, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + ) + + x_BT_HW_D = self.final_layer.forward_with_memory_save( + x_BT_HW_D_before_gate=x_BT_HW_D_before_gate, + x_BT_HW_D_skip=x_BT_HW_D_skip, + gate_L_B_D=gate_L_B_D, + emb_B_D=affline_emb_B_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). 
+ x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + if len(feature_indices) == 0: + # no features requested, return only the model output + return x_B_D_T_H_W + else: + # score and features; score, features + return x_B_D_T_H_W, features + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + x_ctrl: Optional[dict] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + feature_indices: Optional[Container[int]] = None, + return_features_early: bool = False, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + feature_indices: A set of feature indices (a set of integers) decides which blocks + to extract features from. If the set is non-empty, then features will be returned. + By default, feature_indices=None means extract no features. + return_features_early: If true, the forward pass returns the features once the set is complete. + This means the forward pass will not finish completely and no final output is returned. + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + if feature_indices is None: + feature_indices = {} + if return_features_early and len(feature_indices) == 0: + # Exit immediately if user requested this. 
+ return [] + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + if self.use_memory_save: + return self.forward_blocks_memory_save( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + return self.forward_blocks_regular( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + @property + def fsdp_wrap_block_cls(self): + return DITBuildingBlock + + def enable_context_parallel(self, cp_group: ProcessGroup): + cp_ranks = get_process_group_ranks(cp_group) + cp_size = len(cp_ranks) + # Set these attributes for spliting the data after embedding. + self.cp_group = cp_group + # Set these attributes for computing the loss. + self.cp_size = cp_size + + self.pos_embedder.enable_context_parallel(cp_group) + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.enable_context_parallel(cp_group) + # Loop through the model to set up context parallel. + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) + + log.debug(f"[CP] Enable context parallelism with size {cp_size}") + + def disable_context_parallel(self): + self.cp_group = None + self.cp_size = None + + self.pos_embedder.disable_context_parallel() + if self.extra_per_block_abs_pos_emb: + self.extra_pos_embedder.disable_context_parallel() + + # Loop through the model to disable context parallel. 
+ for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["mlp", "ff"]: + continue + elif layer.block_type in ["cross_attn", "ca"]: + continue + else: + layer.block.attn.attn_op.cp_group = None + layer.block.attn.attn_op.cp_ranks = None + layer.block.attn.attn_op.cp_stream = None + + log.debug("[CP] Disable context parallelism.") + + def enable_sequence_parallel(self): + self._set_sequence_parallel(True) + + def disable_sequence_parallel(self): + self._set_sequence_parallel(False) + + def _set_sequence_parallel(self, status: bool): + self.sequence_parallel = status + self.final_layer.sequence_parallel = status + for block in self.blocks.values(): + for layer in block.blocks: + if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: + layer.block.attn.to_q[0].sequence_parallel = status + layer.block.attn.to_k[0].sequence_parallel = status + layer.block.attn.to_v[0].sequence_parallel = status + layer.block.attn.to_out[0].sequence_parallel = status + layer.block.attn.attn_op.sequence_parallel = status + elif layer.block_type in ["mlp", "ff"]: + layer.block.layer1.sequence_parallel = status + layer.block.layer2.sequence_parallel = status + else: + raise ValueError(f"Unknown block type {layer.block_type}") + + @property + def is_context_parallel_enabled(self): + return self.cp_group is not None diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_action.py b/cosmos_predict1/diffusion/training/networks/general_dit_action.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd4db422844da3b422512096e58836c93117e33 --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit_action.py @@ -0,0 +1,515 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +It allows us easy to switch building blocks used and their order. Its instantiation includes +* transformer on fully flattened tokens +* factored spatial and temporal attention +* factored non-overlap spatial and temporal attention +* mixing of above attention types + +Limitations: + +* In favor of simplicity and cleanness, many ops are not fused and we can do better +* such as combining mutiple adaln MLPs into one inside one transformer block. 
+* we use reshape heavily, which may be not efficient when its occurs unnecessary CUDA memory copy + +Purpose: +* A prototype for testing different attention types and their combinations +* Idealy, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies +""" + +from collections.abc import Container +from typing import List, Optional, Tuple + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import nn + +from cosmos_predict1.diffusion.module.timm import Mlp +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_predict1.diffusion.training.tensor_parallel import scatter_along_first_dim +from cosmos_predict1.utils import log + + +class ActionConditionalGeneralDIT(GeneralDIT): + """ + ActionConditionalGeneralDIT is a subclass of GeneralDIT that take `action` as condition. + Action embedding is would be added to timestep embedding. + """ + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + action: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + action: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + x_ctrl: Optional[dict] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + feature_indices: Optional[Container[int]] = None, + return_features_early: bool = False, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + feature_indices: A set of feature indices (a set of integers) decides which blocks + to extract features from. If the set is non-empty, then features will be returned. + By default, feature_indices=None means extract no features. + return_features_early: If true, the forward pass returns the features once the set is complete. + This means the forward pass will not finish completely and no final output is returned. + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + if feature_indices is None: + feature_indices = {} + if return_features_early and len(feature_indices) == 0: + # Exit immediately if user requested this. 
+ return [] + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + action=action, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + if self.use_memory_save: + return self.forward_blocks_memory_save( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + return self.forward_blocks_regular( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D, + adaln_lora_B_3D, + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + feature_indices, + original_shape, + x_ctrl, + return_features_early, + ) + + +class ActionConditionalVideoExtendGeneralDIT(ActionConditionalGeneralDIT): + """ + ActionConditionalVideoExtendGeneralDIT is a subclass of ActionConditionalGeneralDIT that take `action` as condition. + Action embedding is would be added to timestep embedding. 
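+
+    The raw action vector has 7 features (in_features=7 in the MLPs defined in `__init__`); two small
+    MLPs embed it: `action_embedder_B_D` projects it to `model_channels` and `action_embedder_B_3D`
+    to `3 * model_channels`, following the naming/shape convention of the timestep (`_B_D`) and
+    adaLN-LoRA (`_B_3D`) embeddings.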
+ """ + + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + assert hasattr(self, "model_channels"), "model_channels attribute is missing" + self.action_embedder_B_D = Mlp( + in_features=7, + hidden_features=self.model_channels * 4, + out_features=self.model_channels, + act_layer=lambda: nn.GELU(approximate="tanh"), + drop=0, + ) + self.action_embedder_B_3D = Mlp( + in_features=7, + hidden_features=self.model_channels * 4, + out_features=self.model_channels * 3, + act_layer=lambda: nn.GELU(approximate="tanh"), + drop=0, + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + action: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + # log.critical(f"hit video case, video_cond_bool: {video_cond_bool}, condition_video_indicator: {condition_video_indicator.flatten()}, condition_video_input_mask: {condition_video_input_mask.shape}, {condition_video_input_mask[:,:,:,0,0]}", rank0_only=False) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + if data_type == DataType.IMAGE: + # For image, we dont have condition_video_input_mask, or condition_video_pose + # We need to add the extra channel for video condition mask + padding_channels = self.in_channels - x.shape[1] + x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) + else: + assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" + return super().forward( + x=x, + 
timesteps=timesteps, + crossattn_emb=crossattn_emb, + action=action, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + action: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + # Add action conditioning + assert action is not None, "Action is required for action-conditional training" + if action is not None: + action = action[:, 0, :] # Since we are now training on 1 frame, we only need the first frame action. 
+ action_embedding_B_D = self.action_embedder_B_D(action) + action_embedding_B_3D = self.action_embedder_B_3D(action) + timesteps_B_D = timesteps_B_D + action_embedding_B_D + adaln_lora_B_3D = adaln_lora_B_3D + action_embedding_B_3D + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( + condition_video_augment_sigma.flatten() + ) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_lvg.py b/cosmos_predict1/diffusion/training/networks/general_dit_lvg.py new file mode 100644 index 0000000000000000000000000000000000000000..3d344e9b49dbab908611354d8c9f76d4d006bc36 --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit_lvg.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
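The action-conditioning path above maps a 7-dimensional action vector through `action_embedder_B_D` and `action_embedder_B_3D` and adds the results to the timestep embedding and the adaLN-LoRA term; only the first frame's action is used. A minimal, self-contained sketch of that arithmetic follows; it is not part of the patch, and `model_channels`, the plain `nn.Sequential` stand-in for `Mlp`, and all tensor sizes are assumptions chosen for illustration.

# Sketch only: mirrors the action-conditioning arithmetic above with illustrative sizes.
import torch
from torch import nn

model_channels = 256  # assumed for the example

# Stand-ins for the Mlp-based action embedders defined in ActionConditionalVideoExtendGeneralDIT.
action_embedder_B_D = nn.Sequential(
    nn.Linear(7, model_channels * 4),
    nn.GELU(approximate="tanh"),
    nn.Linear(model_channels * 4, model_channels),
)
action_embedder_B_3D = nn.Sequential(
    nn.Linear(7, model_channels * 4),
    nn.GELU(approximate="tanh"),
    nn.Linear(model_channels * 4, model_channels * 3),
)

B = 2
timesteps_B_D = torch.randn(B, model_channels)        # from t_embedder
adaln_lora_B_3D = torch.randn(B, model_channels * 3)  # adaLN-LoRA branch from t_embedder
action_B_T_7 = torch.randn(B, 4, 7)                   # per-frame 7-dim actions

action_B_7 = action_B_T_7[:, 0, :]  # only the first frame's action is used
timesteps_B_D = timesteps_B_D + action_embedder_B_D(action_B_7)
adaln_lora_B_3D = adaln_lora_B_3D + action_embedder_B_3D(action_B_7)
print(timesteps_B_D.shape, adaln_lora_B_3D.shape)  # torch.Size([2, 256]) torch.Size([2, 768])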
+ +from typing import Optional + +import torch +from einops import rearrange +from megatron.core import parallel_state +from torch import nn + +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.diffusion.training.module.blocks import SDXLTimestepEmbedding, SDXLTimesteps +from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_predict1.diffusion.training.tensor_parallel import scatter_along_first_dim +from cosmos_predict1.utils import log + + +class VideoExtendGeneralDIT(GeneralDIT): + def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): + self.add_augment_sigma_embedding = add_augment_sigma_embedding + + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") + + def build_additional_timestamp_embedder(self): + super().build_additional_timestamp_embedder() + if self.add_augment_sigma_embedding: + log.info("Adding augment sigma embedding") + self.augment_sigma_embedder = nn.Sequential( + SDXLTimesteps(self.model_channels), + SDXLTimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), + ) + + def init_weights(self): + if self.add_augment_sigma_embedding: + # Initialize timestep embedding for augment sigma + nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_1.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) + nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) + if self.augment_sigma_embedder[1].linear_2.bias is not None: + nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) + + super().init_weights() # Call this last since it wil call TP weight init + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=2, cp_group=self.cp_group + ) + condition_video_indicator = split_inputs_cp( + condition_video_indicator, seq_dim=2, cp_group=self.cp_group + ) + if condition_video_pose is not None: + condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) + # log.critical(f"hit video case, 
video_cond_bool: {video_cond_bool}, condition_video_indicator: {condition_video_indicator.flatten()}, condition_video_input_mask: {condition_video_input_mask.shape}, {condition_video_input_mask[:,:,:,0,0]}", rank0_only=False) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + if data_type == DataType.IMAGE: + # For image, we dont have condition_video_input_mask, or condition_video_pose + # We need to add the extra channel for video condition mask + padding_channels = self.in_channels - x.shape[1] + x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) + else: + assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + + condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + if self.add_augment_sigma_embedding: + if condition_video_augment_sigma is None: + # Handling image case + # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function + assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" + condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) + + affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( + condition_video_augment_sigma.flatten() + ) + affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_lvg_multiview.py b/cosmos_predict1/diffusion/training/networks/general_dit_lvg_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2641d62a41cedd22ffab3949f531ba743adcae --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit_lvg_multiview.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
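VideoExtendGeneralDIT above reserves one extra input channel for the video condition mask: in the video branch, `condition_video_input_mask` is concatenated along the channel dimension, while the image branch zero-pads up to `in_channels`. The following is a self-contained sketch of that channel bookkeeping, not part of the patch; the shapes are made-up example values.

# Sketch only: the "extra channel for video condition mask" convention used above.
import torch

B, C, T, H, W = 1, 16, 8, 4, 4
in_channels = C + 1  # +1 reserved for the condition mask channel

x = torch.randn(B, C, T, H, W)                    # latent video
condition_video_input_mask = torch.zeros(B, 1, T, H, W)
condition_video_input_mask[:, :, :2] = 1.0        # e.g. condition on the first 2 latent frames

# Video branch: concatenate the mask as an extra channel.
x_video = torch.cat([x, condition_video_input_mask], dim=1)
assert x_video.shape[1] == in_channels

# Image branch: no mask is available, so zero-pad up to in_channels instead.
padding_channels = in_channels - x.shape[1]
x_image = torch.cat([x, torch.zeros(B, padding_channels, T, H, W)], dim=1)
assert x_image.shape[1] == in_channels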
+ +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.diffusion.training.networks.general_dit_multiview import MultiviewGeneralDIT +from cosmos_predict1.utils import log + + +class VideoExtendMultiviewGeneralDIT(MultiviewGeneralDIT): + def __init__(self, *args, in_channels, **kwargs): + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels + 1, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels + 1}") + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views + ) + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=3, cp_group=self.cp_group + ) + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views + ) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) diff --git a/cosmos_predict1/diffusion/training/networks/general_dit_multiview.py b/cosmos_predict1/diffusion/training/networks/general_dit_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..62b16e37a50ec959373f91970bb11209e983710a --- /dev/null +++ b/cosmos_predict1/diffusion/training/networks/general_dit_multiview.py @@ -0,0 +1,460 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from cosmos_predict1.diffusion.training.conditioner import DataType +from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp +from cosmos_predict1.diffusion.training.module.blocks import GeneralDITTransformerBlock, PatchEmbed +from cosmos_predict1.diffusion.training.module.position_embedding import ( + MultiviewSinCosPosEmbAxis, + MultiviewVideoRopePosition3DEmb, +) +from cosmos_predict1.diffusion.training.networks.general_dit import GeneralDIT +from cosmos_predict1.utils import log + + +class MultiviewGeneralDIT(GeneralDIT): + def __init__( + self, + *args, + n_views: int = 3, + view_condition_dim: int = 3, + traj_condition_dim: int = 0, + concat_view_embedding: bool = True, + concat_traj_embedding: bool = False, + add_repeat_frame_embedding: bool = False, + **kwargs, + ): + self.n_views = n_views + self.view_condition_dim = view_condition_dim + self.concat_view_embedding = concat_view_embedding + self.traj_condition_dim = traj_condition_dim + self.concat_traj_embedding = concat_traj_embedding + self.add_repeat_frame_embedding = add_repeat_frame_embedding + + super().__init__(*args, **kwargs) + # reinit self.blocks + del self.blocks + self.blocks = nn.ModuleDict() + + layer_mask = [False] * self.num_blocks if kwargs["layer_mask"] is None else kwargs["layer_mask"] + assert ( + len(layer_mask) == self.num_blocks + ), f"Layer mask length {len(layer_mask)} does not match num_blocks {self.num_blocks}" + for idx in range(self.num_blocks): + if layer_mask[idx]: + continue + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=self.model_channels, + context_dim=kwargs["crossattn_emb_channels"], + num_heads=self.num_heads, + block_config=self.block_config, + window_sizes=( + kwargs["window_sizes"] if idx in kwargs["window_block_indexes"] else [] + ), # There will be bug if using "WA-CA-MLP" + mlp_ratio=kwargs["mlp_ratio"], + spatial_attn_win_size=kwargs["spatial_attn_win_size"], + temporal_attn_win_size=kwargs["temporal_attn_win_size"], + x_format=self.block_x_format, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + n_views=self.n_views, + ) + self.view_embeddings = nn.Embedding(n_views, view_condition_dim) # Learnable embedding layer + + if self.concat_traj_embedding: + self.traj_embeddings = nn.Linear(192, self.traj_condition_dim) # Learnable embedding layer + if self.add_repeat_frame_embedding: + self.repeat_frame_embedding = nn.Linear(1, view_condition_dim) # Learnable embedding layer + + self.init_weights() + + def build_patch_embed(self): + ( + concat_padding_mask, + in_channels, + patch_spatial, + patch_temporal, + model_channels, + view_condition_dim, + traj_condition_dim, + ) = ( + self.concat_padding_mask, + self.in_channels, + self.patch_spatial, + self.patch_temporal, + self.model_channels, + 
self.view_condition_dim, + self.traj_condition_dim, + ) + if self.concat_view_embedding: + in_channels = in_channels + view_condition_dim if view_condition_dim > 0 else in_channels + + if self.concat_traj_embedding: + in_channels = in_channels + traj_condition_dim if traj_condition_dim > 0 else in_channels + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + keep_spatio=True, + legacy_patch_emb=self.legacy_patch_emb, + ) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + if self.legacy_patch_emb: + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def build_pos_embed(self): + if self.pos_emb_cls == "rope3d": + cls_type = MultiviewVideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + n_views=self.n_views, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + assert self.extra_per_block_abs_pos_emb is True, "extra_per_block_abs_pos_emb must be True" + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "sincos", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + self.extra_pos_embedder = MultiviewSinCosPosEmbAxis(**kwargs) + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + trajectory = kwargs.get("trajectory", None) + frame_repeat = kwargs.get("frame_repeat", None) + + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
+ original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + trajectory=trajectory, + frame_repeat=frame_repeat, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) + + if self.additional_timestamp_channels: + additional_cond_B_D = self.prepare_additional_timestamp_embedder( + bs=x.shape[0], + fps=fps, + h=image_size[:, 0], + w=image_size[:, 1], + org_h=image_size[:, 2], + org_w=image_size[:, 3], + ) + + affline_emb_B_D += additional_cond_B_D + affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + # for logging purpose + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = affline_emb_B_D + self.crossattn_emb = crossattn_emb + self.crossattn_mask = crossattn_mask + + if self.use_cross_attn_mask: + crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + if self.sequence_parallel: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Sequence parallel requires the input tensor to be scattered along the first dimension. + assert self.block_config == "FA-CA-MLP" # Only support this block config for now + T, H, W, B, D = x.shape + # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer + x = x.view(T * H * W, 1, 1, B, D) + assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 + x = scatter_along_first_dim(x, tp_group) + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( + T * H * W, 1, 1, B, D + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group + ) + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + trajectory: Optional[torch.Tensor] = None, + frame_repeat: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` + with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
+ """ + if self.concat_padding_mask: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + + view_indices = torch.arange(self.n_views).to(x_B_C_T_H_W.device) # View indices [0, 1, ..., V-1] + view_embedding = self.view_embeddings(view_indices) # Shape: [V, embedding_dim] + view_embedding = rearrange(view_embedding, "V D -> D V") + view_embedding = view_embedding.unsqueeze(0).unsqueeze(3).unsqueeze(4).unsqueeze(5) # Shape: [1, D, V, 1, 1, 1] + + if self.add_repeat_frame_embedding: + if frame_repeat is None: + frame_repeat = ( + torch.zeros([x_B_C_T_H_W.shape[0], view_embedding.shape[1]]) + .to(view_embedding.device) + .to(view_embedding.dtype) + ) + frame_repeat_embedding = self.repeat_frame_embedding(frame_repeat.unsqueeze(-1)) + frame_repeat_embedding = rearrange(frame_repeat_embedding, "B V D -> B D V") + view_embedding = view_embedding + frame_repeat_embedding.unsqueeze(3).unsqueeze(4).unsqueeze(5) + + x_B_C_V_T_H_W = rearrange(x_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=self.n_views) + view_embedding = view_embedding.expand( + x_B_C_V_T_H_W.shape[0], + view_embedding.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + if self.concat_traj_embedding: + traj_emb = self.traj_embeddings(trajectory) + traj_emb = traj_emb.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) + traj_emb = traj_emb.expand( + x_B_C_V_T_H_W.shape[0], + traj_emb.shape[1], + view_embedding.shape[2], + x_B_C_V_T_H_W.shape[3], + x_B_C_V_T_H_W.shape[4], + x_B_C_V_T_H_W.shape[5], + ) # Shape: [B, V, 3, t, H, W] + + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding, traj_emb], dim=1) + else: + x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding], dim=1) + + x_B_C_T_H_W = rearrange(x_B_C_V_T_H_W, " B C V T H W -> B C (V T) H W", V=self.n_views) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] + return x_B_T_H_W_D, None, extra_pos_emb + + +class VideoExtendGeneralDIT(MultiviewGeneralDIT): + def __init__(self, *args, in_channels, **kwargs): + # extra channel for video condition mask + super().__init__(*args, in_channels=in_channels + 1, **kwargs) + log.info(f"VideoExtendGeneralDIT in_channels: {in_channels + 1}") + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + video_cond_bool: Optional[torch.Tensor] = None, + condition_video_indicator: Optional[torch.Tensor] = None, + condition_video_input_mask: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + condition_video_pose: 
Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Args: + condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation + condition_video_pose: (B, 1, T, H, W) tensor of pose condition + """ + B, C, T, H, W = x.shape + + if data_type == DataType.VIDEO: + assert ( + condition_video_input_mask is not None + ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" + if self.cp_group is not None: + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views + ) + condition_video_input_mask = split_inputs_cp( + condition_video_input_mask, seq_dim=3, cp_group=self.cp_group + ) + condition_video_input_mask = rearrange( + condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views + ) + input_list = [x, condition_video_input_mask] + if condition_video_pose is not None: + if condition_video_pose.shape[2] > T: + log.warning( + f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" + ) + condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() + input_list.append(condition_video_pose) + x = torch.cat( + input_list, + dim=1, + ) + + return super().forward( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) diff --git a/cosmos_predict1/diffusion/training/tensor_parallel.py b/cosmos_predict1/diffusion/training/tensor_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..c756c38e53d2c71f3ab3fa2b08859fdf1bb96bc5 --- /dev/null +++ b/cosmos_predict1/diffusion/training/tensor_parallel.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
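In the multiview `prepare_embedded_sequence` above, a learnable per-view embedding is broadcast over the (T, H, W) axes of each view and concatenated to the latent channels before patch embedding. A standalone sketch of that broadcast follows, not part of the patch; all dimensions are small example values.

# Sketch only: per-view embedding broadcast and concatenation, as in the multiview DiT above.
import torch
from torch import nn
from einops import rearrange, repeat

B, C, V, T, H, W, view_condition_dim = 1, 16, 3, 4, 8, 8, 3
view_embeddings = nn.Embedding(V, view_condition_dim)

x_B_C_VT_H_W = torch.randn(B, C, V * T, H, W)         # views stacked along the time axis
view_emb_V_D = view_embeddings(torch.arange(V))        # [V, D]
view_emb = repeat(view_emb_V_D, "V D -> B D V T H W", B=B, T=T, H=H, W=W)

x_B_C_V_T_H_W = rearrange(x_B_C_VT_H_W, "B C (V T) H W -> B C V T H W", V=V)
x_cat = torch.cat([x_B_C_V_T_H_W, view_emb], dim=1)    # channels grow to C + view_condition_dim
x_out = rearrange(x_cat, "B C V T H W -> B C (V T) H W")
print(x_out.shape)  # torch.Size([1, 19, 12, 8, 8])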
+ +import os + +import torch +import torch.distributed as dist +from torch.autograd import Function + + +class AllGather(Function): + @staticmethod + def forward(ctx, tensor, process_group): + world_size = dist.get_world_size(process_group) + ctx.world_size = world_size + ctx.rank = process_group.rank() + + gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gathered_tensors, tensor.contiguous(), process_group) + return torch.cat(gathered_tensors, dim=0) + + @staticmethod + def backward(ctx, grad_output): + world_size = ctx.world_size + rank = ctx.rank + + # Split the gradient tensor + grad_chunks = grad_output.chunk(world_size) + + # Select the gradient chunk for the current rank + grad_input = grad_chunks[rank] + return grad_input, None + + +def gather_along_first_dim(tensor, process_group): + return AllGather.apply(tensor, process_group) + + +class Scatter(Function): + @staticmethod + def forward(ctx, tensor, process_group): + world_size = dist.get_world_size(process_group) + ctx.world_size = world_size + ctx.process_group = process_group + rank = process_group.rank() + + # Split the tensor + tensor_chunks = tensor.chunk(world_size) + + # Select the tensor chunk for the current rank + return tensor_chunks[rank] + + @staticmethod + def backward(ctx, grad_output): + world_size = ctx.world_size + process_group = ctx.process_group + + # Gather the gradient tensor + gathered_grads = [torch.zeros_like(grad_output) for _ in range(world_size)] + dist.all_gather(gathered_grads, grad_output.contiguous(), process_group) + return torch.cat(gathered_grads, dim=0), None + + +def scatter_along_first_dim(tensor, process_group): + return Scatter.apply(tensor, process_group) + + +if __name__ == "__main__": + # Torch global setup for distributed training + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Create a tensor with gradients + x = torch.randn(10, 1, requires_grad=True, device="cuda") + + # Perform all_gather with gradient support + y = gather_along_first_dim(x, dist.group.WORLD) + print(f"{y.shape=}") + y = scatter_along_first_dim(y, dist.group.WORLD) + print(f"{y.shape=}") + + # Use the result in your computation + loss = y.sum() + loss.backward() + + # x.grad now contains the gradients + print(x.grad) diff --git a/cosmos_predict1/diffusion/training/train.py b/cosmos_predict1/diffusion/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a65eece39e5cc993918dd53be44d87b8c70d28a0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/train.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
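The sequence-parallel branch in the DiT `forward_before_blocks` above flattens the (T, H, W) axes, checks divisibility by the tensor-parallel world size, and hands the tensor to `scatter_along_first_dim`, whose backward pass is the matching all-gather. Below is a sketch of the shape bookkeeping only, not part of the patch; `tensor.chunk()` stands in for the distributed scatter and the world size is an assumed value.

# Sketch only: the shape bookkeeping behind scatter/gather_along_first_dim above.
import torch

T, H, W, B, D = 4, 8, 8, 2, 64
tp_world_size = 4  # assumed tensor-parallel world size

x_T_H_W_B_D = torch.randn(T, H, W, B, D)
x = x_T_H_W_B_D.view(T * H * W, 1, 1, B, D)  # flatten THW so the sequence can be scattered
assert x.shape[0] % tp_world_size == 0       # same divisibility check as the model

shards = x.chunk(tp_world_size, dim=0)       # scatter_along_first_dim keeps shard [rank]
local_x = shards[0]                          # what tensor-parallel rank 0 would see
print(local_x.shape)                         # torch.Size([64, 1, 1, 2, 64])

# gather_along_first_dim undoes this with an all-gather along dim 0:
restored = torch.cat(shards, dim=0).view(T, H, W, B, D)
assert torch.equal(restored, x_T_H_W_B_D)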
+ +import argparse +import importlib +import os + +import torch.distributed as dist +from loguru import logger as logging +from omegaconf import OmegaConf + +from cosmos_predict1.diffusion.config.config import Config +from cosmos_predict1.utils import log, misc +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyConfig +from cosmos_predict1.utils.parallel_state_helper import is_tp_cp_pp_rank0 + + +@misc.timer("instantiate model") +def instantiate_model(config: Config, trainer) -> None: + misc.set_random_seed(seed=config.trainer.seed, by_rank=False) + config.model_obj.config = config.model + if getattr(config.model, "fsdp_enabled", False): + assert config.trainer.distributed_parallelism == "fsdp", "FSDP model is only supported with FSDP trainer" + log.critical("FSDP enabled") + config.model_obj.fsdp_checkpointer = trainer.checkpointer + model = instantiate(config.model_obj) + config.model_obj.fsdp_checkpointer = None + else: + model = instantiate(config.model_obj) + config.model_obj.config = None + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + return model + + +def destroy_distributed(): + log.info("Destroying distributed environment...") + if dist.is_available() and dist.is_initialized(): + try: + dist.destroy_process_group() + except ValueError as e: + print(f"Error destroying default process group: {e}") + + +@logging.catch(reraise=True) +def launch(config: Config, args: argparse.Namespace) -> None: + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. + config.freeze() # type: ignore + trainer = config.trainer.type(config) + # # Setup the miscellaneous stuff for reproducibility. + # log_reproducible_setup(config, args) + # Create the model + model = instantiate_model(config, trainer) + model.on_model_init_end() + # Create the dataloaders. + if args.mp0_only_dl: + log.critical( + "Using only tp_cp_pp_rank0 dataloader for faster dataloading! Make sure val dl is mock and mock data has same keys as real data." + ) + raise NotImplementedError( + "mp0_only_dl is not implemented correctly! Please revisit this code and propose a more robust impl that raise error timely! It does not do necessary check before training to confirm it can work with image / video data. Current impl is problematic for image training." + ) + if is_tp_cp_pp_rank0() or not args.mp0_only_dl: + dataloader_train = instantiate(config.dataloader_train) + else: + dataloader_train = instantiate(config.dataloader_val) + dataloader_val = instantiate(config.dataloader_val) + # Start training + trainer.train( + model, + dataloader_train, + dataloader_val, + ) + destroy_distributed() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Training") + parser.add_argument( + "--config", + default="cosmos_predict1/diffusion/posttrain/config/config.py", + help="Path to the config file", + ) + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dry run without training. 
Useful for debugging the config.", + ) + parser.add_argument( + "--mp0_only_dl", + action="store_true", + help="Use only model parallel rank 0 dataloader for faster dataloading! Make sure mock data has same keys as real data.", + ) + args = parser.parse_args() + config_module = get_config_module(args.config) + config = importlib.import_module(config_module).make_config() + config = override(config, args.opts) + if args.dryrun: + os.makedirs(config.job.path_local, exist_ok=True) + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + print(OmegaConf.to_yaml(OmegaConf.load(f"{config.job.path_local}/config.yaml"))) + print(f"{config.job.path_local}/config.yaml") + else: + # Launch the training job. + launch(config, args) diff --git a/cosmos_predict1/diffusion/training/trainer.py b/cosmos_predict1/diffusion/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e68de41c528512404a4c43ade569faaace3caa08 --- /dev/null +++ b/cosmos_predict1/diffusion/training/trainer.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.diffusion.training.utils.checkpointer import MultiRankCheckpointer +from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer +from cosmos_predict1.utils.trainer import Trainer as BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, config): + super(Trainer, self).__init__(config) + if config.trainer.distributed_parallelism == "ddp": + self.checkpointer = MultiRankCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + elif config.trainer.distributed_parallelism == "fsdp": + self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + else: + raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}") diff --git a/cosmos_predict1/diffusion/training/utils/checkpointer.py b/cosmos_predict1/diffusion/training/utils/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..069ed708b9247a1b53c7e06d34c159f4d6d9d2e1 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/checkpointer.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
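The `train.py` CLI above collects trailing `path.key=value` pairs through an `argparse.REMAINDER` positional and forwards them to `override()`. A small sketch of how those overrides are captured follows; the argument values are made up, and `override()` itself is not reproduced here.

# Sketch only: how the trailing "opts" arguments are gathered before reaching override().
import argparse

parser = argparse.ArgumentParser(description="Training")
parser.add_argument("--config", default="cosmos_predict1/diffusion/posttrain/config/config.py")
parser.add_argument("--dryrun", action="store_true")
parser.add_argument("opts", nargs=argparse.REMAINDER, default=None)

args = parser.parse_args(
    ["--config", "my_config.py", "--dryrun", "trainer.max_iter=1000", "job.name=debug_run"]
)
print(args.dryrun)  # True
print(args.opts)    # ['trainer.max_iter=1000', 'job.name=debug_run']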
+ +from __future__ import annotations + +import os +import threading +from typing import List, NamedTuple, Tuple + +import torch + +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.checkpointer import Checkpointer as BaseCheckpointer +from cosmos_predict1.utils.model import Model + +TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) +if TORCH_VERSION >= (1, 11): + from torch.ao import quantization + from torch.ao.quantization import FakeQuantizeBase, ObserverBase +elif ( + TORCH_VERSION >= (1, 8) + and hasattr(torch.quantization, "FakeQuantizeBase") + and hasattr(torch.quantization, "ObserverBase") +): + from torch import quantization + from torch.quantization import FakeQuantizeBase, ObserverBase + + +class _IncompatibleKeys( + NamedTuple( + "IncompatibleKeys", + [ + ("missing_keys", List[str]), + ("unexpected_keys", List[str]), + ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), + ], + ) +): + pass + + +class MultiRankCheckpointer(BaseCheckpointer): + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + # checkpoint_file = f"iter_{iteration:09}.pt" + postfix, _, total_ema_num = model.get_ckpt_postfix() + checkpoint_file = f"iter_{iteration:09}{postfix}.pt" + save_ranks = list(range(total_ema_num)) + for _rank in save_ranks: + if distributed.get_rank() == _rank: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). 
+ scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # different from base checkpointer, this support multi-EMA + postfix, _, total_ema_num = model.get_ckpt_postfix() + latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt") + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + # different from base checkpointer, this support multi-EMA + postfix, _, total_ema_num = model.get_ckpt_postfix() + checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt") + resume = self.load_training_state + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + # Load the state dicts. + log.info("- Loading the model...") + log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume)) + if resume: + iteration = state_dict["iteration"] + assert optimizer and scheduler + log.info("- Loading the optimizer...") + optimizer.load_state_dict(state_dict["optimizer"]) + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + scheduler.last_epoch = iteration + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(state_dict["grad_scaler"]) + log.success(f"Done with loading the checkpoint (iteration {iteration}).") + else: + iteration = 0 + log.success("Done with loading the checkpoint.") + else: + # Checkpoint not found and not specified. We will train everything from scratch. + iteration = 0 + log.info("Training from scratch.") + torch.cuda.empty_cache() + return iteration + + +# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py +def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: + # workaround https://github.com/pytorch/pytorch/issues/24139 + model_state_dict = model.state_dict() + incorrect_shapes = [] + for k in list(checkpoint_state_dict.keys()): + if k in model_state_dict: + if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 + log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") + continue + model_param = model_state_dict[k] + # Allow mismatch for uninitialized parameters + if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): + continue + if not isinstance(model_param, torch.Tensor): + raise ValueError( + f"Find non-tensor parameter {k} in the model. 
type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." + ) + + shape_model = tuple(model_param.shape) + shape_checkpoint = tuple(checkpoint_state_dict[k].shape) + if shape_model != shape_checkpoint: + has_observer_base_classes = ( + TORCH_VERSION >= (1, 8) + and hasattr(quantization, "ObserverBase") + and hasattr(quantization, "FakeQuantizeBase") + ) + if has_observer_base_classes: + # Handle the special case of quantization per channel observers, + # where buffer shape mismatches are expected. + def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: + # foo.bar.param_or_buffer_name -> [foo, bar] + key_parts = key.split(".")[:-1] + cur_module = model + for key_part in key_parts: + cur_module = getattr(cur_module, key_part) + return cur_module + + cls_to_skip = ( + ObserverBase, + FakeQuantizeBase, + ) + target_module = _get_module_for_key(model, k) + if isinstance(target_module, cls_to_skip): + # Do not remove modules with expected shape mismatches + # them from the state_dict loading. They have special logic + # in _load_from_state_dict to handle the mismatches. + continue + + incorrect_shapes.append((k, shape_checkpoint, shape_model)) + checkpoint_state_dict.pop(k) + incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) + # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling + missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] + unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] + return _IncompatibleKeys( + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + incorrect_shapes=incorrect_shapes, + ) diff --git a/cosmos_predict1/diffusion/training/utils/fsdp_helper.py b/cosmos_predict1/diffusion/training/utils/fsdp_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4c1f6150e6d71c9fc626867ae11541e83c0134 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/fsdp_helper.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from contextlib import contextmanager +from functools import partial + +import torch +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp._runtime_utils import ( + _post_forward, + _post_forward_reshard, + _pre_forward, + _pre_forward_unshard, + _root_pre_forward, +) +from torch.distributed.utils import _p_assert + +from cosmos_predict1.utils import distributed, log + + +def apply_fsdp_checkpointing(model, list_block_cls): + """Apply activation checkpointing to the model in place. + Returns None; the model is updated directly. Checkpointing is applied to every + submodule whose class is listed in `list_block_cls`. + """ + log.critical("--> applying fsdp activation checkpointing...") + non_reentrant_wrapper = partial( + checkpoint_wrapper, + # offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + + def check_fn(submodule): + result = False + for block_cls in list_block_cls: + if isinstance(submodule, block_cls): + result = True + break + return result + + apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) + + +@contextmanager +def possible_fsdp_scope( + model: torch.nn.Module, +): + enabled = isinstance(model, FSDP) + if enabled: + assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled" + handle = model._handle + args, kwargs = [0], dict(dummy=0) + with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"): + args, kwargs = _root_pre_forward(model, model, args, kwargs) + unused = None + args, kwargs = _pre_forward( + model, + handle, + _pre_forward_unshard, + model._fsdp_wrapped_module, + args, + kwargs, + ) + if handle: + _p_assert( + handle.flat_param.device == model.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{model.compute_device} but got {handle.flat_param.device}", + ) + try: + yield None + finally: + if enabled: + output = {"output": 1} + _post_forward(model, handle, _post_forward_reshard, model, unused, output) + + +def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None): + """ + Initializes a device mesh for use with the Hybrid Sharding strategy in FSDP (HSDP) training. + + Explicit sizes for the replica and sharding groups can be provided; if omitted, sensible + defaults are derived from the world size, providing flexibility in distributed training setups. + + Args: + replica_group_size (int, optional): The size of each replica group. If None, it is derived + as world_size // sharding_group_size. + sharding_group_size (int, optional): The size of each sharding group, i.e. how many GPUs the + model is sharded across. If None, it defaults to min(world_size, 8). + device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" + with the local rank as the device index. + + Returns: + A device mesh object compatible with FSDP. + + Raises: + ValueError: If the world size is not evenly divisible by the sharding group size, or if the + resulting number of sharding groups is not evenly divisible by replica_group_size. + RuntimeError: If a valid device mesh cannot be created. 
+ + Usage: + If your model fits on 4 GPUs and you have 3 nodes of 8 GPUs, then: + sharding_group_size = 4 + replica_group_size = 24 total GPUs / 4 GPUs per sharding group = 6 replica groups + >>> device_mesh = hsdp_device_mesh(replica_group_size, sharding_group_size) + >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) + """ + + # world_size = int(os.getenv("WORLD_SIZE", "1")) + world_size = distributed.get_world_size() + if sharding_group_size is None: + sharding_group_size = min(world_size, 8) + sharding_group_size = min(sharding_group_size, world_size) + if replica_group_size is None: + replica_group_size = world_size // sharding_group_size + + device = device or "cuda" + + if world_size % sharding_group_size != 0: + raise ValueError( + f"World size {world_size} is not evenly divisible by " f"sharding group size {sharding_group_size}." + ) + + if (world_size // sharding_group_size) % replica_group_size != 0: + raise ValueError( + f"The calculated number of replica groups is not evenly divisible by " + f"replica_group_size {replica_group_size}." + ) + + device_mesh = init_device_mesh( + device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard") + ) + if device_mesh is None: + raise RuntimeError("Failed to create a valid device mesh.") + + log.critical( + f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}" + ) + + return device_mesh diff --git a/cosmos_predict1/diffusion/training/utils/inference_long_video.py b/cosmos_predict1/diffusion/training/utils/inference_long_video.py new file mode 100644 index 0000000000000000000000000000000000000000..8feb222a733666a710cbb98900e2055b3a231cc1 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/inference_long_video.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager +from typing import Tuple, Union + +import einops +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional as transforms_F +from matplotlib import pyplot as plt + +from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel +from cosmos_predict1.utils import log +from cosmos_predict1.utils.easy_io import easy_io + +"""This file contains functions needed for long video generation. +* function `generate_video_from_batch_with_loop` is used by `single_gpu_sep20` + +""" + + +@contextmanager +def switch_config_for_inference(model): + """For extend-model inference, condition_location must be "first_n" (or "first_and_last_1") and no random corruption may be applied to the condition region. + This context manager switches the model configuration to the correct settings for inference (e.g. "gaussian_blur" becomes "clean", "noise_with_sigma" becomes "noise_with_sigma_fixed"), and then restores the original settings when exiting the context. 
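+ + Example (illustrative sketch; the surrounding batch preparation and generation arguments are assumed and elided): + >>> with switch_config_for_inference(model): + ... video, _, _ = generate_video_from_batch_with_loop(model, ...)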
+ Args: + model (ExtendDiffusionModel): video generation model + """ + # Store the current condition_location + current_condition_location = model.config.conditioner.video_cond_bool.condition_location + if current_condition_location != "first_n" and current_condition_location != "first_and_last_1": + current_condition_location = "first_n" + current_apply_corruption_to_condition_region = ( + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region + ) + try: + log.info( + "Change the condition_location to 'first_n' for inference, and apply_corruption_to_condition_region to False" + ) + # Change the condition_location to "first_n" for inference + model.config.conditioner.video_cond_bool.condition_location = current_condition_location + if current_apply_corruption_to_condition_region == "gaussian_blur": + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = "clean" + elif current_apply_corruption_to_condition_region == "noise_with_sigma": + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = "noise_with_sigma_fixed" + # Yield control back to the calling context + yield + finally: + # Restore the original condition_location after exiting the context + log.info( + f"Restore the original condition_location {current_condition_location}, apply_corruption_to_condition_region {current_apply_corruption_to_condition_region}" + ) + model.config.conditioner.video_cond_bool.condition_location = current_condition_location + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region = ( + current_apply_corruption_to_condition_region + ) + + +def visualize_latent_tensor_bcthw(tensor, nrow=1, show_norm=False, save_fig_path=None): + """Debug function to display a latent tensor as a grid of images. + Args: + tensor (torch.Tensor): tensor in shape BCTHW + nrow (int): number of images per row + show_norm (bool): whether to display the norm of the tensor + save_fig_path (str): path to save the visualization + + """ + log.info( + f"display latent tensor shape {tensor.shape}, max={tensor.max()}, min={tensor.min()}, mean={tensor.mean()}, std={tensor.std()}" + ) + tensor = tensor.float().cpu().detach() + tensor = einops.rearrange(tensor, "b c (t n) h w -> (b t h) (n w) c", n=nrow) # .numpy() + # display the grid + tensor_mean = tensor.mean(-1) + tensor_norm = tensor.norm(dim=-1) + log.info(f"tensor_norm, tensor_mean {tensor_norm.shape}, {tensor_mean.shape}") + plt.figure(figsize=(20, 20)) + plt.imshow(tensor_mean) + plt.title(f"mean {tensor_mean.mean()}, std {tensor_mean.std()}") + if save_fig_path: + os.makedirs(os.path.dirname(save_fig_path), exist_ok=True) + log.info(f"save to {os.path.abspath(save_fig_path)}") + plt.savefig(save_fig_path, bbox_inches="tight", pad_inches=0) + plt.show() + if show_norm: + plt.figure(figsize=(20, 20)) + plt.imshow(tensor_norm) + plt.show() + + +def visualize_tensor_bcthw(tensor: torch.Tensor, nrow=4, save_fig_path=None): + """Debug function to display a tensor as a grid of images. 
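+ + Example (illustrative; the tensor variable and output path are placeholders): + >>> visualize_tensor_bcthw(video_tensor, nrow=4, save_fig_path="debug/grid.png")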
+ Args: + tensor (torch.Tensor): tensor in shape BCTHW + nrow (int): number of images per row + save_fig_path (str): path to save the visualization + """ + log.info(f"display {tensor.shape}, {tensor.max()}, {tensor.min()}") + assert tensor.max() < 200, f"tensor max {tensor.max()} > 200, the data range is likely wrong" + tensor = tensor.float().cpu().detach() + tensor = einops.rearrange(tensor, "b c t h w -> (b t) c h w") + # use torchvision to save the tensor as a grid of images + grid = torchvision.utils.make_grid(tensor, nrow=nrow) + if save_fig_path is not None: + os.makedirs(os.path.dirname(save_fig_path), exist_ok=True) + log.info(f"save to {os.path.abspath(save_fig_path)}") + torchvision.utils.save_image(tensor, save_fig_path) + # display the grid + plt.figure(figsize=(20, 20)) + plt.imshow(grid.permute(1, 2, 0)) + plt.show() + + +def compute_num_frames_condition(model: "ExtendDiffusionModel", num_of_latent_overlap: int, downsample_factor=8) -> int: + """This function computes the number of condition pixel frames given the number of latent frames to overlap. + Args: + model (ExtendDiffusionModel): Video generation model + num_of_latent_overlap (int): Number of latent frames to overlap + downsample_factor (int): Downsample factor for temporal reduce + Returns: + int: Number of condition frames in output space + """ + # Access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly + vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer + + # Check if the VAE is causal (default to True if attribute not found) + if getattr(vae, "is_casual", True): + # For causal model + num_frames_condition = num_of_latent_overlap // vae.latent_chunk_duration * vae.pixel_chunk_duration + if num_of_latent_overlap % vae.latent_chunk_duration == 1: + num_frames_condition += 1 + elif num_of_latent_overlap % vae.latent_chunk_duration > 1: + num_frames_condition += 1 + (num_of_latent_overlap % vae.latent_chunk_duration - 1) * downsample_factor + else: + num_frames_condition = num_of_latent_overlap * downsample_factor + + return num_frames_condition + + +def read_video_or_image_into_frames_BCTHW( + input_path: str, + input_path_format: str = None, + H: int = None, + W: int = None, + normalize: bool = True, + max_frames: int = -1, + also_return_fps: bool = False, +) -> torch.Tensor: + """Read video or image from file and convert it to tensor. The frames will be normalized to [-1, 1]. 
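+ + Example (illustrative; the path and target resolution are placeholders): + >>> video_BCTHW, fps = read_video_or_image_into_frames_BCTHW("input.mp4", H=704, W=1280, also_return_fps=True)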
+ Args: + input_path (str): path to the input video or image, end with .mp4 or .png or .jpg + H (int): height to resize the video + W (int): width to resize the video + Returns: + torch.Tensor: video tensor in shape (1, C, T, H, W), range [-1, 1] + """ + log.info(f"Reading video from {input_path}") + + loaded_data = easy_io.load(input_path, file_format=input_path_format, backend_args=None) + if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"): + frames = np.array(loaded_data) # HWC, [0,255] + if frames.shape[-1] > 3: # RGBA, set the transparent to white + # Separate the RGB and Alpha channels + rgb_channels = frames[..., :3] + alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1] + + # Create a white background + white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB + + # Blend the RGB channels with the white background based on the alpha channel + frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype( + np.uint8 + ) + frames = [frames] + fps = 0 + else: + frames, meta_data = loaded_data + fps = int(meta_data.get("fps")) + if max_frames != -1: + frames = frames[:max_frames] + input_tensor = np.stack(frames, axis=0) + input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w") + if normalize: + input_tensor = input_tensor / 128.0 - 1.0 + input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW + log.info(f"Raw data shape: {input_tensor.shape}") + if H is not None and W is not None: + input_tensor = transforms_F.resize( + input_tensor, + size=(H, W), # type: ignore + interpolation=transforms_F.InterpolationMode.BICUBIC, + antialias=True, + ) + input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1) + if normalize: + input_tensor = input_tensor.to("cuda") + log.info(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}") + if also_return_fps: + return input_tensor, fps + return input_tensor + + +def create_condition_latent_from_input_frames( + model: ExtendDiffusionModel, + input_frames: torch.Tensor, + num_frames_condition: int = 25, +): + """Create condition latent for video generation. It will take the last num_frames_condition frames from the input frames as condition latent. 
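+ + Example (illustrative sketch; `input_frames` is assumed to come from read_video_or_image_into_frames_BCTHW): + >>> condition_latent, encode_input_frames = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition=25)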
+ Args: + model (ExtendDiffusionModel): Video generation model + input_frames (torch.Tensor): Video tensor in shape (1,C,T,H,W), range [-1, 1] + num_frames_condition (int): Number of condition frames + Returns: + torch.Tensor: Condition latent in shape B,C,T,H,W + """ + B, C, T, H, W = input_frames.shape + # Dynamically access the VAE: use tokenizer.video_vae if it exists, otherwise use tokenizer directly + vae = model.tokenizer.video_vae if hasattr(model.tokenizer, "video_vae") else model.tokenizer + num_frames_encode = vae.pixel_chunk_duration # Access pixel_chunk_duration from the VAE + log.info( + f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}" + ) + + log.info( + f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}" + ) + + assert ( + input_frames.shape[2] >= num_frames_condition + ), f"input_frames not enough for condition, require at least {num_frames_condition}, got {input_frames.shape[2]}, {input_frames.shape}" + assert ( + num_frames_encode >= num_frames_condition + ), f"num_frames_encode should be larger than num_frames_condition, got {num_frames_encode}, {num_frames_condition}" + + # Put the conditional frames at the beginning of the video, and pad the end with zeros + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + condition_frames_first = input_frames[:, :, :num_frames_condition] + condition_frames_last = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2) + else: + condition_frames = input_frames[:, :, -num_frames_condition:] + padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) + encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) + + log.info( + f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end" + ) + if hasattr(model, "n_views"): + encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views) + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode]) # BCTHW + latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:]) + latent = torch.cat([latent1, latent2], dim=2) # BCTHW + else: + latent = model.encode(encode_input_frames) + return latent, encode_input_frames + + +def get_condition_latent( + model: ExtendDiffusionModel, + conditioned_image_or_video_path: str, + num_of_latent_condition: int = 4, + state_shape: list[int] = None, + input_path_format: str = None, + frame_index: int = 0, + frame_stride: int = 1, +): + if state_shape is None: + state_shape = model.state_shape + if num_of_latent_condition == 0: + log.info("No condition latent needed, return empty latent") + condition_latent = ( + torch.zeros( + [ + 1, + ] + + state_shape + ) + .to(torch.bfloat16) + .cuda() + ) + return condition_latent, None + + H, W = ( + state_shape[-2] * model.vae.spatial_compression_factor, + state_shape[-1] * model.vae.spatial_compression_factor, + ) + input_frames = read_video_or_image_into_frames_BCTHW( + conditioned_image_or_video_path, + input_path_format=input_path_format, + H=H, + 
W=W, + ) + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + start_frame = frame_index * frame_stride + end_frame = (frame_index + 1) * frame_stride + input_frames = torch.cat( + [input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2 + ).contiguous() # BCTHW + + num_frames_condition = compute_num_frames_condition( + model, num_of_latent_condition, downsample_factor=model.vae.temporal_compression_factor + ) + + condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition) + condition_latent = condition_latent.to(torch.bfloat16) + return condition_latent, input_frames + + +def generate_video_from_batch_with_loop( + model: ExtendDiffusionModel, + state_shape: list[int], + is_negative_prompt: bool, + data_batch: dict, + condition_latent: torch.Tensor, + # hyper-parameters for inference + num_of_loops: int, + num_of_latent_overlap_list: list[int], + guidance: float, + num_steps: int, + seed: int, + add_input_frames_guidance: bool = False, + augment_sigma_list: list[float] = None, + data_batch_list: Union[None, list[dict]] = None, + visualize: bool = False, + save_fig_path: str = None, + skip_reencode: int = 0, + return_noise: bool = False, +) -> Tuple[np.array, list, list, torch.Tensor] | Tuple[np.array, list, list, torch.Tensor, torch.Tensor]: + """Generate video with loop, given data batch. The condition latent will be updated at each loop. + Args: + model (ExtendDiffusionModel) + state_shape (list): shape of the state tensor + is_negative_prompt (bool): whether to use negative prompt + + data_batch (dict): data batch for video generation + condition_latent (torch.Tensor): condition latent in shape BCTHW + + num_of_loops (int): number of loops to generate video + num_of_latent_overlap_list (list[int]): list number of latent frames to overlap between clips, different clips can have different overlap + guidance (float): The guidance scale to use during sample generation; defaults to 5.0. + num_steps (int): number of steps for diffusion sampling + seed (int): random seed for sampling + add_input_frames_guidance (bool): whether to add image guidance, default is False + augment_sigma_list (list): list of sigma value for the condition corruption at different clip, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed". default is None + + data_batch_list (list): list of data batch for video generation, used when num_of_loops >= 1, to support multiple prompts in auto-regressive generation. default is None + visualize (bool): whether to visualize the latent and grid, default is False + save_fig_path (str): path to save the visualization, default is None + + skip_reencode (int): whether to skip re-encode the input frames, default is 0 + return_noise (bool): whether to return the initial noise used for sampling, used for ODE pairs generation. 
Default is False + Returns: + np.array: generated video in shape THWC, range [0, 255] + list: list of condition latent, each in shape BCTHW + list: list of sample latent, each in shape BCTHW + torch.Tensor: initial noise used for sampling, shape BCTHW (if return_noise is True) + """ + + if data_batch_list is None: + data_batch_list = [data_batch for _ in range(num_of_loops)] + if visualize: + assert save_fig_path is not None, "save_fig_path should be set when visualize is True" + + # Generate video with loop + condition_latent_list = [] + decode_latent_list = [] # list collect the latent token to be decoded at the end + sample_latent = [] + grid_list = [] + + augment_sigma_list = ( + model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region_sigma_value + if augment_sigma_list is None + else augment_sigma_list + ) + + for i in range(num_of_loops): + num_of_latent_overlap_i = num_of_latent_overlap_list[i] + num_of_latent_overlap_i_plus_1 = ( + num_of_latent_overlap_list[i + 1] + if i + 1 < len(num_of_latent_overlap_list) + else num_of_latent_overlap_list[-1] + ) + if condition_latent.shape[2] < state_shape[1]: + # Padding condition latent to state shape + log.info(f"Padding condition latent {condition_latent.shape} to state shape {state_shape}") + b, c, t, h, w = condition_latent.shape + condition_latent = torch.cat( + [ + condition_latent, + condition_latent.new_zeros(b, c, state_shape[1] - t, h, w), + ], + dim=2, + ).contiguous() + log.info(f"after padding, condition latent shape {condition_latent.shape}") + log.info(f"Generate video loop {i} / {num_of_loops}") + if visualize: + log.info(f"Visualize condition latent {i}") + visualize_latent_tensor_bcthw( + condition_latent[:, :, :4].float(), + nrow=4, + save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_condition_latent_first_4.png"), + ) # BCTHW + + condition_latent_list.append(condition_latent) + + if i < len(augment_sigma_list): + condition_video_augment_sigma_in_inference = augment_sigma_list[i] + log.info(f"condition_video_augment_sigma_in_inference {condition_video_augment_sigma_in_inference}") + else: + condition_video_augment_sigma_in_inference = augment_sigma_list[-1] + assert not add_input_frames_guidance, "add_input_frames_guidance should be False, not supported" + + sample = model.generate_samples_from_batch( + data_batch_list[i], + guidance=guidance, + state_shape=state_shape, + num_steps=num_steps, + is_negative_prompt=is_negative_prompt, + seed=seed + i, + condition_latent=condition_latent, + num_condition_t=num_of_latent_overlap_i, + condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, + return_noise=return_noise, + ) + + if return_noise: + sample, noise = sample + + if visualize: + log.info(f"Visualize sampled latent {i} 4-8 frames") + visualize_latent_tensor_bcthw( + sample[:, :, 4:8].float(), + nrow=4, + save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_sample_latent_last_4.png"), + ) # BCTHW + + diff_between_sample_and_condition = (sample - condition_latent)[:, :, :num_of_latent_overlap_i] + log.info( + f"Visualize diff between sample and condition latent {i} first 4 frames {diff_between_sample_and_condition.mean()}" + ) + + sample_latent.append(sample) + T = condition_latent.shape[2] + assert num_of_latent_overlap_i <= T, f"num_of_latent_overlap should be < T, get {num_of_latent_overlap_i}, {T}" + + if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + assert skip_reencode, "skip_reencode should be turned on when 
sample_tokens_start_from_p_or_i is True" + if i == 0: + decode_latent_list.append(sample) + else: + decode_latent_list.append(sample[:, :, num_of_latent_overlap_i:]) + else: + # Interpolator mode. Decode the first and last as an image. + if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": + grid_BCTHW_1 = (1.0 + model.decode(sample[:, :, :-1, ...])).clamp(0, 2) / 2 # [B, 3, T-1, H, W], [0, 1] + grid_BCTHW_2 = (1.0 + model.decode(sample[:, :, -1:, ...])).clamp(0, 2) / 2 # [B, 3, 1, H, W], [0, 1] + grid_BCTHW = torch.cat([grid_BCTHW_1, grid_BCTHW_2], dim=2) # [B, 3, T, H, W], [0, 1] + else: + grid_BCTHW = (1.0 + model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W], [0, 1] + + if visualize: + log.info(f"Visualize grid {i}") + visualize_tensor_bcthw( + grid_BCTHW.float(), nrow=5, save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_grid.png") + ) + grid_np_THWC = ( + (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) + ) # THW3, range [0, 255] + + # Post-process the output: cut the conditional frames from the output if it's not the first loop + num_cond_frames = compute_num_frames_condition( + model, num_of_latent_overlap_i_plus_1, downsample_factor=model.tokenizer.temporal_compression_factor + ) + if i == 0: + new_grid_np_THWC = grid_np_THWC # First output, dont cut the conditional frames + else: + new_grid_np_THWC = grid_np_THWC[ + num_cond_frames: + ] # Remove the conditional frames from the output, since it's overlapped with previous loop + grid_list.append(new_grid_np_THWC) + + # Prepare the next loop: re-compute the condition latent + if hasattr(model, "n_views"): + grid_BCTHW = einops.rearrange(grid_BCTHW, "B C (V T) H W -> (B V) C T H W", V=model.n_views) + condition_frame_input = grid_BCTHW[:, :, -num_cond_frames:] * 2 - 1 # BCTHW, range [0, 1] to [-1, 1] + if skip_reencode: + # Use the last num_of_latent_overlap latent token as condition latent + log.info(f"Skip re-encode the condition frames, use the last {num_of_latent_overlap_i_plus_1} latent token") + condition_latent = sample[:, :, -num_of_latent_overlap_i_plus_1:] + else: + # Re-encode the condition frames to get the new condition latent + condition_latent, _ = create_condition_latent_from_input_frames( + model, condition_frame_input, num_frames_condition=num_cond_frames + ) # BCTHW + condition_latent = condition_latent.to(torch.bfloat16) + + # save videos + if model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: + # decode all video together + decode_latent_list = torch.cat(decode_latent_list, dim=2) + grid_BCTHW = (1.0 + model.decode(decode_latent_list)).clamp(0, 2) / 2 # [B, 3, T, H, W], [0, 1] + video_THWC = ( + (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) + ) # THW3, range [0, 255] + else: + video_THWC = np.concatenate(grid_list, axis=0) # THW3, range [0, 255] + + if return_noise: + return video_THWC, condition_latent_list, sample_latent, noise + return video_THWC, condition_latent_list, sample_latent diff --git a/cosmos_predict1/diffusion/training/utils/layer_control/peft_control_config_parser.py b/cosmos_predict1/diffusion/training/utils/layer_control/peft_control_config_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1d38076f2ecd368bcf26b992f71b4a5c5bb9fe81 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/layer_control/peft_control_config_parser.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from collections import defaultdict +from typing import Union + +from loguru import logger +from omegaconf import DictConfig, ListConfig + +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType +from cosmos_predict1.utils.validator import Float, Int, OneOf + + +class LayerControlConfigParser: + """ + Parses a config to select layers, blocks, and subblocks to apply LoRA, PEFT, and other fine-grained post-training techniques. + A base model is first loaded, then edits (i.e. LoRA, unfreeze, etc.) are applied to the model. Currently, only LoRA is supported for to_q, to_k, to_v, to_out attention layers. + See: cosmos_predict1/diffusion/training/utils/peft/lora_config.py and LoRA diffusion post-training for an example of how to create and use a LoRA config. + The input config is a dictionary with the following keys: + - enabled: whether to apply the PEFT + - customization_type: default/global type of PEFT to apply (LoRA, unfreeze, etc.) + - rank: default/global LoRA rank + - scale: default/global LoRA scale + - edits: a list of model edits to apply. + - blocks: a regex to select the blocks to apply the edit to: eg: r'\b(0|1|25|26)\b' + - block_edit: a list of subblocks to apply the edit to: eg: ["FA[to_q, to_v]", "CA[to_q, to_v]"]. + Subblock names correspond to FA (Full-Attention), CA (Cross-Attention), FL (FinalLayer), and MLP modules as defined in general_dit.py, + and the layers (i.e to_q, to_k, to_v, etc.) are defined in corresponding modules in attention.py. + - customization_type: type of PEFT to apply for the edit (LoRA, unfreeze, etc.) - overrides the global customization_type if provided + - rank: LoRA rank - overrides the global rank for target blocks and subblocks if provided + - scale: LoRA scale - overrides the global scale for target blocks and subblocks if provided + """ + + SUBBLOCK_PATTERN = r"^(?P<subblock>.+?)\[(?P<parameters>[^\]]+)\]$" # determines the subblock type (i.e. "FA[...]") + LAYER_PATTERN = r"^(?P<layer>.+?)(?::(?P<rank>.+?))?(?::(?P<scale>[\d\.]+))?$" # determines the layer details (i.e. to_q:8:0.6 or to_q) + FINAL_LAYER_NAME = "final_layer" + DEFAULT_ALLOWED_TYPES = { # subblock type to layer types + "FA": {"to_q", "to_k", "to_v", "to_out", "ada1", "ada2"}, + "CA": {"to_q", "to_k", "to_v", "to_out", "ada1", "ada2"}, + "MLP": {"l1", "l2", "ada1", "ada2"}, + } + + DEFAULT_VALUE_CONSTRAINTS = ( + { # field to allowed ranges. these ranges are not prescriptive and can be adjusted as needed. 
+ "blocks": {"min": 0, "max": 27}, + "rank": {"min": 1, "max": 512}, + "scale": {"min": 1e-5, "max": 64}, + } + ) + ALLOWED_TYPES_FINAL_LAYER = {"FL": {"l1", "ada1", "ada2"}} + + def __init__(self, config: Union[str, dict] = {}, allowed_types: dict = None, value_constraints: dict = None): + self.config = self._config_to_dict(config) + self.enabled = str(self.config.get("enabled", "False")).lower() in ( + "true", + "1", + "yes", + ) # if not set, assume disabled + if self.enabled and not self.config.get("customization_type", ""): + raise AttributeError("Must specify a top-level customization_type.") + self.default_customization_type = CustomizationType.from_value(self.config.get("customization_type", "")) + self.default_rank = self.config.get("rank", None) + self.default_scale = self.config.get("scale", None) + + self.allowed_types = allowed_types or self.DEFAULT_ALLOWED_TYPES + self.value_constraints = value_constraints or self.DEFAULT_VALUE_CONSTRAINTS + logger.info( + f"Creating layers config with allowed subblock + layer types: \n{self.allowed_types} and value constraints: \n{self.value_constraints}" + ) + self.allowed_types_final_layer = self.ALLOWED_TYPES_FINAL_LAYER + + self._set_validators() + + self.all_blocks_str = ( + ",".join( + str(i) + for i in range( + self.value_constraints.get("blocks").get("min"), self.value_constraints.get("blocks").get("max") + 1 + ) + ) + + "," + + self.FINAL_LAYER_NAME + ) + + self.edits_per_block = defaultdict(lambda: None) + + def _set_validators(self): + """ + Sets validators for blocks, subblocks, rank, and scale. + + Raises: + AttributeError: If value constraints are not properly defined. + """ + self.subblock_validator = OneOf(default="", options=self.allowed_types.keys()) + self.final_layer_validator = OneOf(default="", options=self.allowed_types_final_layer.keys()) + self.rank_validator = None + self.scale_validator = None + try: + self.rank_validator = Int( + default=0, + min=self.value_constraints.get("rank").get("min"), + max=self.value_constraints.get("rank").get("max"), + ) + self.scale_validator = Float( + default=0, + min=self.value_constraints.get("scale").get("min"), + max=self.value_constraints.get("scale").get("max"), + ) + except AttributeError: + raise AttributeError( + "Value Constraints dictionary must contain 'blocks', 'rank', and 'scale' attributes with 'min' and 'max' attributes for each" + ) + + def _config_to_dict(self, config): + """ + Convert the given config into a dictionary if provided as a string. + + Args: + config (Union[str, dict]): The configuration as a JSON string or dictionary. + + Returns: + dict: The configuration as a dictionary. + + Raises: + ValueError: If the JSON string is invalid. + TypeError: If the config is not a string or dictionary. + """ + if isinstance(config, str): + try: + config = json.loads(config) + except json.JSONDecodeError: + raise ValueError("Invalid JSON string provided") + elif not isinstance(config, (dict, DictConfig)): + raise TypeError(f"Config should be either a JSON string or a dictionary, but got {type(config)}") + return config + + def _parse_blocks_regex(self, regex): + """ + Parse the 'blocks' regex and return a set of matching block numbers. + Allowed block numbers: defined in value_constraints, plus 'final_layer' + + Args: + regex (str): The regex pattern to match block numbers. + + Returns: + set: A set of block numbers that match the regex. + + Raises: + ValueError: If the regex pattern is invalid or matches invalid block numbers. 
+ Exception: If 'final_layer' is defined with other blocks. + """ + try: + block_matches = re.findall(regex, self.all_blocks_str) + block_numbers = set() + for match in block_matches: + match = match.strip() + if match == "final_layer": + block_numbers.add(match) + else: + try: + block_numbers.add(int(match)) + except ValueError: + raise ValueError(f"Invalid match found: '{match}' is neither an integer nor 'final_layer'.") + except re.error as e: + raise ValueError(f"Invalid regex pattern provided: {regex}. Error: {e}") + + # as final_layer contains a different block type than other blocks, must be defined separately + if "final_layer" in block_numbers and len(block_numbers) > 1: + raise Exception(f"Block 'final_layer' must be defined separately, but got: {block_numbers}") + + return block_numbers + + def _parse_subblocks( + self, + block_edit: list | ListConfig, + customization_type: str, + rank: int, + scale: float, + is_final_layer: bool = False, + ): + """Generate a dictionary of edits config by subblock. + + Args: + block_edit (list): List of representing subblocks to apply the edit to (i.e ["FA[to_q, to_v]", "CA[to_q, to_v]"]) + customization_type (str): The type of PEFT to apply. + rank (int): The LoRA rank. + scale (float): The LoRA scale. + is_final_layer (bool): Indicates if this edit is for the final layer. + + Returns: + defaultdict: A dictionary of subblock edits configs. + + Raises: + TypeError: If block_edit is not a list. + AttributeError: If subblock format is incorrect or layer format is invalid. + ValueError: If rank and scale values are not provided. + """ + sb_dict = defaultdict(lambda: None) + + if not isinstance(block_edit, (list, ListConfig)): + raise TypeError(f"Config 'block_edits' field must be a list, but got {type(block_edit)}") + + if is_final_layer: # final layer has different allowed layer names + subblock_validator = self.final_layer_validator + allowed_types = self.allowed_types_final_layer + else: + subblock_validator = self.subblock_validator + allowed_types = self.allowed_types + + for subblock in block_edit: + sb_name = None + params_list = None + try: + sb_match = re.match(self.SUBBLOCK_PATTERN, subblock) + sb_name = subblock_validator.validate(sb_match.group("subblock")) + params_str = sb_match.group("parameters") + params_list = params_str.replace(" ", "").split(",") + except AttributeError: + raise AttributeError("Incorrect sub-block format: must be [...]") + layer_validator = OneOf(default="", options=allowed_types.get(sb_name)) + + # for each parameter in the subblock config + layers_dict = defaultdict(lambda: None) + for param in params_list: + try: + layer_match = re.match(self.LAYER_PATTERN, param) + layer_name = layer_validator.validate(layer_match.group("layer")) + layer_rank = layer_match.group("rank") or rank or self.default_rank + layer_scale = layer_match.group("scale") or scale or self.default_scale + if not layer_rank or not layer_scale: + raise ValueError( + "Rank and scale values must be provided at default, sub-block, or layer level." 
+ ) + layer_rank = self.rank_validator.validate(layer_rank) + layer_scale = self.scale_validator.validate(layer_scale) + + layers_dict[layer_name] = {"activate": True, "lora_rank": layer_rank, "lora_scale": layer_scale} + layers_dict["customization_type"] = customization_type or self.default_customization_type + sb_dict[sb_name] = dict(layers_dict) + except AttributeError: + raise AttributeError("Layer format must be layer:rank:scale (where rank and scale are optional)") + + if sb_dict: + sb_dict["customization_type"] = customization_type or self.default_customization_type + return sb_dict + + def parse(self): + """ + Parse the loaded config into a dictionary of edit configs by block number. + + Returns: + dict: A dictionary of edit configs applied to each block. + + Raises: + Exception: If more than one edit is specified for a block. + """ + if not self.enabled: + return {} + + # for each edit in the config + for edit in self.config.get("edits", []): + blocks = self._parse_blocks_regex(edit["blocks"]) # get the blocks affected by edit + logger.info(f"Applying edits for blocks {blocks}") + block_edit = edit.get("block_edit", []) + customization_type = CustomizationType.from_value(edit.get("customization_type", "")) + rank = edit.get("rank", None) + scale = edit.get("scale", None) + is_final_layer = blocks == set([self.FINAL_LAYER_NAME]) + # get subblock config + sb_dict = self._parse_subblocks( + block_edit=block_edit, + customization_type=customization_type, + rank=rank, + scale=scale, + is_final_layer=is_final_layer, + ) + + # for each block in the edit + for block in blocks: + if sb_dict: + if self.edits_per_block[block]: + raise Exception(f"More than one edit specified for block {block}") + self.edits_per_block[block] = dict(sb_dict) + if self.edits_per_block: + self.edits_per_block["customization_type"] = self.default_customization_type + return dict(self.edits_per_block) diff --git a/cosmos_predict1/diffusion/training/utils/optim_instantiate.py b/cosmos_predict1/diffusion/training/utils/optim_instantiate.py new file mode 100644 index 0000000000000000000000000000000000000000..96f45a65ffd11fd472efd34287bb867f69379986 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/optim_instantiate.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +import torch +from torch import nn + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.fused_adam import FusedAdam + + +def get_regular_param_group(net: nn.Module): + """ + Separate the parameters of the network into two groups: decay and no_decay. + Based on the nano_gpt codebase. 
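+ + Example (sketch): + >>> decay_params, nodecay_params = get_regular_param_group(net) + >>> # parameters with dim >= 2 (weights) go to decay_params; 1-D parameters such as biases and norm scales go to nodecay_params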
+ """ + param_dict = {pn: p for pn, p in net.named_parameters()} + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + return decay_params, nodecay_params + + +def get_base_optimizer( + model: nn.Module, + lr: float, + weight_decay: float, + optim_type: str = "adamw", + sharding: bool = False, + **kwargs, +) -> torch.optim.Optimizer: + net_decay_param, net_nodecay_param = get_regular_param_group(model) + + num_decay_params = sum(p.numel() for p in net_decay_param) + num_nodecay_params = sum(p.numel() for p in net_nodecay_param) + net_param_total = num_decay_params + num_nodecay_params + log.critical(f"total num parameters : {net_param_total:,}") + + param_group = [ + { + "params": net_decay_param + net_nodecay_param, + "lr": lr, + "weight_decay": weight_decay, + }, + ] + + if optim_type == "adamw": + opt_cls = torch.optim.AdamW + elif optim_type == "fusedadam": + opt_cls = FusedAdam + else: + raise ValueError(f"Unknown optimizer type: {optim_type}") + + return opt_cls(param_group, **kwargs) + + +def get_base_scheduler( + optimizer: torch.optim.Optimizer, + model: nn.Module, + scheduler_config: dict, +): + net_scheduler = hydra.utils.instantiate(scheduler_config) + net_scheduler.model = model + + return torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=[ + net_scheduler.schedule, + ], + ) diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_attn.py b/cosmos_predict1/diffusion/training/utils/peft/lora_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..269ac32205117d820bb0f24f3dcac9e976439299 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_attn.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from einops import rearrange +from torch.utils.checkpoint import checkpoint +from transformer_engine.pytorch.attention import apply_rotary_pos_emb + +from cosmos_predict1.diffusion.module.attention import Attention +from cosmos_predict1.diffusion.training.utils.peft.lora_net import LoRALinearLayer, TELoRALinearLayer +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType + +try: + from megatron.core import parallel_state + + USE_MEGATRON = True +except ImportError: + USE_MEGATRON = False + + +def enable_attn_lora(attn: Attention, peft_control: dict) -> None: + """ + Enable LoRA for the attention block based on the peft_control dictionary. + + Args: + attn (Attention): The attention block to configure. + peft_control (dict): Dictionary containing PEFT configuration. 
+ """ + attn.peft_lora_enabled = False + if peft_control: + try: + if peft_control["customization_type"] == CustomizationType.LORA: + attn.peft_lora_enabled = True + else: + raise Exception(f"Unsupported Customization type {peft_control['customization_type']}") + except KeyError as e: + raise KeyError(f"peft_control dictionary expected to have attribute {e.args[0]}.") + + +def configure_attn_lora(attn: Attention, peft_control: dict) -> None: + """ + Configure LoRA for the attention block based on the peft_control dictionary. + + Args: + attn (Attention): The attention block to configure. + peft_control (dict): Dictionary containing PEFT configuration. + """ + try: + attn.q_lora_enabled = peft_control.get("to_q", {}).get("activate", False) + attn.k_lora_enabled = peft_control.get("to_k", {}).get("activate", False) + attn.v_lora_enabled = peft_control.get("to_v", {}).get("activate", False) + attn.out_lora_enabled = peft_control.get("to_out", {}).get("activate", False) + if attn.q_lora_enabled: + attn.q_lora_rank = peft_control["to_q"]["lora_rank"] + attn.q_lora_scale = float(peft_control["to_q"]["lora_scale"]) + if attn.k_lora_enabled: + attn.k_lora_rank = peft_control["to_k"]["lora_rank"] + attn.k_lora_scale = float(peft_control["to_k"]["lora_scale"]) + if attn.v_lora_enabled: + attn.v_lora_rank = peft_control["to_v"]["lora_rank"] + attn.v_lora_scale = float(peft_control["to_v"]["lora_scale"]) + if attn.out_lora_enabled: + attn.out_lora_rank = peft_control["to_out"]["lora_rank"] + attn.out_lora_scale = float(peft_control["to_out"]["lora_scale"]) + except KeyError as e: + raise KeyError(f"All layers (to_q, etc) specified must have attribute {e.args[0]}.") + except ValueError as e: + raise ValueError(f"Could not convert string to float: {e}") + + +def cal_qkv_lora( + self, + x: torch.Tensor, + context: torch.Tensor = None, + mask: torch.Tensor = None, + rope_emb: torch.Tensor = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + del kwargs + """ + Calculate the Q, K, V matrices with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_qkv. + + Args: + x (torch.Tensor): Input tensor. + context (torch.Tensor, optional): Context tensor + mask (torch.Tensor, optional): Mask tensor + rope_emb (torch.Tensor, optional): Rotary positional embedding + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The Q, K, V matrices. + """ + + q = self.to_q[0](x) + context = x if context is None else context + k = self.to_k[0](context) + v = self.to_v[0](context) + + if self.peft_lora_enabled: + try: + if self.q_lora_enabled: + q_lora = self.to_q_lora(x) + q = q + self.q_lora_scale * q_lora + if self.k_lora_enabled: + k_lora = self.to_k_lora(context) + k = k + self.k_lora_scale * k_lora + if self.v_lora_enabled: + v_lora = self.to_v_lora(context) + v = v + self.v_lora_scale * v_lora + except AttributeError as e: + raise AttributeError(f"lora enabled, but missing class attribute {e.args[0]} of Attention block") + + q, k, v = map( + lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads // self.tp_size, c=self.dim_head), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb): + q = self.to_q[1](q) + k = self.to_k[1](k) + v = self.to_v[1](v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
+ q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True) + k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True) + return q, k, v + + q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False) + + return q, k, v + + +def cal_attn_lora(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate the attention output with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_attn. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + v (torch.Tensor): Value tensor. + mask (torch.Tensor, optional): Mask tensor. + + Returns: + torch.Tensor: The attention output. + """ + if self.backend == "transformer_engine": + seq_dim = self.qkv_format.index("s") + assert ( + q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 + ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." + attn_out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V] + out = self.to_out(attn_out) + + if self.peft_lora_enabled and self.out_lora_enabled: + try: + out_lora = self.to_out_lora(attn_out) + out = out + self.out_lora_scale * out_lora + except AttributeError as e: + raise AttributeError(f"l1 lora enabled, but missing class attribute {e.args[0]} of FeedForward block") + + return out + elif self.backend == "torch": + attn_out = self.attn_op(q, k, v, mask=mask) # [B, Mq, H, V] + attn_out = rearrange(attn_out, " b ... n c -> b ... (n c)") + out = self.to_out(attn_out) + + if self.peft_lora_enabled and self.out_lora_enabled: + try: + out_lora = self.to_out_lora(attn_out) + out = out + self.out_lora_scale * out_lora + except AttributeError as e: + raise AttributeError(f"l1 lora enabled, but missing class attribute {e.args[0]} of FeedForward block") + + return out + else: + raise ValueError(f"Backend {self.backend} not found") + + +def build_attn_lora(attn: Attention, peft_control: dict) -> None: + """ + Configure, build and add LoRA layers to the attention block. + + Args: + attn (Attention): The attention block to add LoRA layers to. + peft_control (dict): Dictionary containing PEFT configuration. 
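+ + Example (illustrative sketch; `parsed` is assumed to be the output of LayerControlConfigParser.parse(), and block 1 / sub-block "FA" are placeholders): + >>> build_attn_lora(attn, parsed[1]["FA"])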
+ """ + enable_attn_lora(attn, peft_control) + configure_attn_lora(attn, peft_control) + if attn.peft_lora_enabled: + query_dim = attn.query_dim + inner_dim = attn.inner_dim + context_dim = attn.context_dim + tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None + + if attn.tp_size == 1: + if attn.q_lora_enabled: + attn.to_q_lora = LoRALinearLayer(query_dim, inner_dim, rank=attn.q_lora_rank, linear=True) + if attn.k_lora_enabled: + attn.to_k_lora = LoRALinearLayer(context_dim, inner_dim, rank=attn.k_lora_rank, linear=True) + if attn.v_lora_enabled: + attn.to_v_lora = LoRALinearLayer(context_dim, inner_dim, rank=attn.v_lora_rank, linear=True) + if attn.out_lora_enabled: + attn.to_out_lora = LoRALinearLayer(inner_dim, query_dim, rank=attn.out_lora_rank, linear=True) + else: + sequence_parallel = getattr(parallel_state, "sequence_parallel", False) + if attn.q_lora_enabled: + attn.to_q_lora = TELoRALinearLayer( + query_dim, + inner_dim, + rank=attn.q_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="column", + ) + if attn.k_lora_enabled: + attn.to_k_lora = TELoRALinearLayer( + context_dim, + inner_dim, + rank=attn.k_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="column", + ) + if attn.v_lora_enabled: + attn.to_v_lora = TELoRALinearLayer( + context_dim, + inner_dim, + rank=attn.v_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="column", + ) + if attn.out_lora_enabled: + attn.to_out_lora = TELoRALinearLayer( + inner_dim, + query_dim, + rank=attn.out_lora_rank, + linear=True, + tp_size=attn.tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode="row", + ) + attn.cal_qkv = cal_qkv_lora.__get__(attn, attn.__class__) + attn.cal_attn = cal_attn_lora.__get__(attn, attn.__class__) diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py b/cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cc52b1a55234eb4b3b6b0cfa84acf9f0be42b0af --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: + pytest -s cosmos_predict1/diffusion/training/utils/peft/lora_attn_test.py +""" + +import copy + +import pytest +import torch +import torch.nn as nn +from einops import rearrange, repeat +from loguru import logger + +from cosmos_predict1.diffusion.config.base.net import FADITV2Config +from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser +from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, get_all_lora_params +from cosmos_predict1.utils.lazy_config import instantiate + + +class DummyModel(nn.Module): + def __init__(self): + super().__init__() + dummy_net = copy.deepcopy(FADITV2Config) + dummy_net.num_blocks = 2 + dummy_net.model_channels = 256 + dummy_net.num_heads = 8 + self.net = instantiate(dummy_net).cuda() + + +@pytest.fixture() +def block1_peft_control(): + """ + This config has the following edits for the following blocks: + Block 0: FA, CA edits for ALL sub-blocks + """ + config = { + "enabled": "True", + "edits": [ + { + "blocks": "\\b\\d*([1])\\b", + "customization_type": "LoRA", + "rank": 8, + "scale": 0.6, + "block_edit": [ + "FA[to_q:8:0.8, to_k:16:1.2, to_v:4:64, to_out:8]", + "CA[to_q:16, to_k:16, to_v:4, to_out:32]", + ], + }, + ], + "customization_type": "LoRA", + "rank": 8, + "scale": 0.8, + } + config_parser = LayerControlConfigParser(config) + return config_parser.parse() + + +def test_model_without_lora(): + model = DummyModel() + lora_params = get_all_lora_params(model) + actual = len(lora_params) + expected = 0 + assert actual == expected, f"Expected {expected} LoRA layers, got {actual}" + + +def test_model_with_lora(block1_peft_control): + model = DummyModel() + add_lora_layers(model, block1_peft_control) + lora_params = get_all_lora_params(model) + actual = len(lora_params) + expected = 16 + assert actual == expected, f"Expected {expected} LoRA layers, got {actual}" + + +def test_model_cal_qkv_lora_matches_base_version_at_init(block1_peft_control): + model = DummyModel() + # isolate a single attention layer + block_idx = 1 + attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + x = torch.rand(2, 16, 256).cuda() # batch size, seq len, embed size + + q_base, k_base, v_base = attn.cal_qkv(x) + add_lora_layers(model, block1_peft_control) + model.cuda() + q_lora, k_lora, v_lora = attn.cal_qkv(x) + + assert torch.allclose(q_base, q_lora) + assert torch.allclose(k_base, k_lora) + assert torch.allclose(v_base, v_lora) + + +def test_model_cal_qkv_lora_with_non_zero_lora(block1_peft_control): + model = DummyModel() + block_idx = 1 + self_attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + cross_attn = model.net.blocks[f"block{block_idx}"].blocks[1].block.attn + # Set q_norm and k_norm to Identity + for attn in [self_attn, cross_attn]: + attn.to_q[0].weight.data.fill_(0.1) + attn.to_k[0].weight.data.fill_(0.1) + attn.to_v[0].weight.data.fill_(0.1) + attn.to_q[1] = nn.Identity() # Set normalization to Identity + attn.to_k[1] = nn.Identity() + attn.to_v[1] = nn.Identity() + attn.to_q[1].cuda() + attn.to_k[1].cuda() + attn.to_v[1].cuda() + + q_base, k_base, v_base = {}, {}, {} + x = torch.ones(2, 16, 256).cuda() # batch size, seq len, embed size + cross_attn_context = torch.ones(2, 16, 1024).cuda() + context_dim = {"FA": 256, "CA": 1024} + input_context = {"FA": x, "CA": cross_attn_context} + + # Compute base qkv for both self and cross attention + for attn_name, attn in [("FA", self_attn), ("CA", cross_attn)]: + q_base[attn_name], 
k_base[attn_name], v_base[attn_name] = attn.cal_qkv(x, input_context[attn_name]) + # add lora layers + add_lora_layers(model, block1_peft_control) + model.cuda() + + # compute lora qkv with non-zero lora weights + for attn_name, attn in [("FA", self_attn), ("CA", cross_attn)]: + attn.to_q_lora.net[0].weight.data.fill_(0.1) + attn.to_q_lora.net[1].weight.data.fill_(0.2) + + attn.to_k_lora.net[0].weight.data.fill_(0.3) + attn.to_k_lora.net[1].weight.data.fill_(0.4) + + attn.to_v_lora.net[0].weight.data.fill_(0.5) + attn.to_v_lora.net[1].weight.data.fill_(0.6) + + q_lora, k_lora, v_lora = attn.cal_qkv(x, input_context[attn_name]) + + # Compare with expected lora qkv + self_attn_q_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_q", {}).get("lora_scale") + ) + self_attn_q_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_q", {}).get("lora_rank") + ) + q_lora_diff = 256 * 0.1 * self_attn_q_lora_rank * 0.2 + + self_attn_k_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_k", {}).get("lora_scale") + ) + self_attn_k_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_k", {}).get("lora_rank") + ) + k_lora_diff = context_dim[attn_name] * 0.3 * self_attn_k_lora_rank * 0.4 + + self_attn_v_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_v", {}).get("lora_scale") + ) + self_attn_v_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_v", {}).get("lora_rank") + ) + v_lora_diff = context_dim[attn_name] * 0.5 * self_attn_v_lora_rank * 0.6 + + expected_q_lora = q_base[attn_name] + self_attn_q_lora_scale * q_lora_diff + expected_k_lora = k_base[attn_name] + self_attn_k_lora_scale * k_lora_diff + expected_v_lora = v_base[attn_name] + self_attn_v_lora_scale * v_lora_diff + logger.info(f"attn_name: {attn_name}, q_lora: {q_lora.shape}, expected_q_lora: {expected_q_lora.shape}") + assert torch.allclose( + q_lora, expected_q_lora, rtol=1e-2 + ), f"q_lora: {q_lora[0, 0, 0, :2]}, expected_q_lora: {expected_q_lora[0, 0, 0, :2]}" + assert torch.allclose( + k_lora, expected_k_lora, rtol=1e-2 + ), f"k_lora: {k_lora[0, 0, 0, :2]}, expected_k_lora: {expected_k_lora[0, 0, 0, :2]}" + assert torch.allclose( + v_lora, expected_v_lora, rtol=1e-2 + ), f"v_lora: {v_lora[0, 0, 0, :2]}, expected_v_lora: {expected_v_lora[0, 0, 0, :2]}" + + +def test_model_cal_attn_lora_matches_base_version_at_init(block1_peft_control): + model = DummyModel() + q = torch.rand(2, 16, 8, 32).cuda() + k = torch.rand(2, 16, 8, 32).cuda() + v = torch.rand(2, 16, 8, 32).cuda() + + # isolate a single attention layer + block_idx = 1 + attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + attn_output_base = attn.cal_attn(q, k, v) # [2, 16, 256] + + add_lora_layers(model, block1_peft_control) + model.cuda() + attn_output_lora = attn.cal_attn(q, k, v) + + assert torch.allclose(attn_output_base, attn_output_lora) + + +def test_model_cal_attn_lora_with_non_zero_output_lora(block1_peft_control): + model = DummyModel() + block_idx = 1 + self_attn = model.net.blocks[f"block{block_idx}"].blocks[0].block.attn + cross_attn = model.net.blocks[f"block{block_idx}"].blocks[1].block.attn + for attn_name, attn in [("FA", self_attn), ("CA", cross_attn)]: + # Overwrite attn_op to return ones of shape [2, 16, 256] and output_dropout to be Identity + class OnesAttnOp(nn.Module): + def forward(self, *args, **kwargs): + return torch.ones([2, 16, 
256]).cuda() + + attn.attn_op = OnesAttnOp() + attn.to_out[0].weight.data.fill_(0.1) + attn.to_out[1] = nn.Identity() # Remove dropout + + # Compute base attn output + q = torch.rand(2, 16, 8, 32).cuda() + k = torch.rand(2, 16, 8, 32).cuda() + v = torch.rand(2, 16, 8, 32).cuda() + attn_output_base = attn.cal_attn(q, k, v) + + # Add lora layers + add_lora_layers(model, block1_peft_control) + model.cuda() + # Set lora weights to non-zero + attn.to_out_lora.net[0].weight.data.fill_(0.1) + attn.to_out_lora.net[1].weight.data.fill_(0.2) + + # Compute lora attn output + attn_output_lora = attn.cal_attn(q, k, v) + + # Compare with expected lora attn output + output_lora_scale = float( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_out", {}).get("lora_scale") + ) + output_lora_rank = int( + block1_peft_control.get(block_idx, {}).get(attn_name, {}).get("to_out", {}).get("lora_rank") + ) + + expected_attn_output_lora = attn_output_base + output_lora_scale * 256 * 0.1 * output_lora_rank * 0.2 + assert torch.allclose( + attn_output_lora, expected_attn_output_lora, rtol=1e-2 + ), f"attn_output_lora: {attn_output_lora[0, 0, :2]}, expected_attn_output_lora: {expected_attn_output_lora[0, 0, :2]}" diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_config.py b/cosmos_predict1/diffusion/training/utils/peft/lora_config.py new file mode 100644 index 0000000000000000000000000000000000000000..896d32f396a36c126323170462f452691d5f1be0 --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_config.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_fa_ca_qv_lora_config(first_nblocks=28, rank=8, scale=1): + """ + Get a LoRA configuration for the Self-Attention (FA) and Cross-Attention (CA) blocks in the model. + This LoRA configuration is used to inject LoRA parameters into the model. + + Args: + first_nblocks (int): The number of blocks to apply LoRA to. + rank (int): The rank of the LoRA matrices. + """ + blocks_regex = r"\b(" + "|".join([str(i) for i in range(first_nblocks)]) + r")\b" + return dict( + enabled=True, + customization_type="LoRA", + rank=rank, + scale=scale, + edits=[ + dict( + blocks=blocks_regex, + customization_type="LoRA", + rank=rank, + scale=scale, + block_edit=[ + "FA[to_q, to_v]", + "CA[to_q, to_v]", + ], + ) + ], + ) diff --git a/cosmos_predict1/diffusion/training/utils/peft/lora_net.py b/cosmos_predict1/diffusion/training/utils/peft/lora_net.py new file mode 100644 index 0000000000000000000000000000000000000000..da7dbe2224613b4a52a4cd79707261c1b707a20b --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/lora_net.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
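Aside on lora_config.py above: get_fa_ca_qv_lora_config selects blocks with a word-boundary regex over block indices. A quick sanity check of that pattern (first_nblocks=4 is an arbitrary example value; how the regex is ultimately applied to block names is up to LayerControlConfigParser and is not shown here):

import re

first_nblocks = 4  # example value; the function's default is 28
blocks_regex = r"\b(" + "|".join(str(i) for i in range(first_nblocks)) + r")\b"
print(blocks_regex)                            # \b(0|1|2|3)\b

assert re.search(blocks_regex, "3") is not None
assert re.search(blocks_regex, "4") is None
assert re.search(blocks_regex, "13") is None   # matches whole indices only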
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import transformer_engine as te +from megatron.core import parallel_state +from torch import nn + +from cosmos_predict1.utils import log + + +class LoRALinearLayer(nn.Module): + """ + ported from + https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470. + """ + + def __init__(self, in_features, out_features, rank=4, linear=False): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + if linear: + down = nn.Linear(in_features, rank, bias=False) + up = nn.Linear(rank, out_features, bias=False) + else: + down = nn.Conv1d(in_features, rank, 1, bias=False) + up = nn.Conv1d(rank, out_features, 1, bias=False) + + nn.init.normal_(down.weight, std=1 / rank) + nn.init.zeros_(up.weight) + self.net = nn.Sequential(down, up) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.net[0].weight.dtype + + up_hidden_states = self.net(hidden_states.to(dtype)) + + return up_hidden_states.to(orig_dtype) + + +class TELoRALinearLayer(nn.Module): + """ + ported from + https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470. 
+ """ + + def __init__(self, in_features, out_features, rank, linear, tp_size, tp_group, sequence_parallel, parallel_mode): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + if linear: + down = te.pytorch.Linear( + in_features, + rank, + bias=False, + tp_size=1, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=None, + ) + up = te.pytorch.Linear( + rank, + out_features, + bias=False, + tp_size=tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=parallel_mode, + ) + else: + down = te.pytorch.Conv1d( + in_features, + rank, + 1, + bias=False, + tp_size=1, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=None, + ) + up = te.pytorch.Conv1d( + rank, + out_features, + 1, + bias=False, + tp_size=tp_size, + tp_group=tp_group, + sequence_parallel=sequence_parallel, + parallel_mode=parallel_mode, + ) + tp_rank = parallel_state.get_tensor_model_parallel_rank() + # Create generator + gen = torch.Generator(device=down.weight.device) + # Save the current random state + gen_state = gen.get_state() + + # Set constant seed for non-tp layers + log.info(f"rank {tp_rank}: setting seed to 0") + gen.manual_seed(0) + nn.init.normal_(down.weight, std=1 / rank, generator=gen) + # Set a new random seed based on the tensor parallel rank + gen.manual_seed(tp_rank) + log.info(f"rank {tp_rank}: setting seed to {tp_rank}") + nn.init.zeros_(up.weight) + # Restore the original random state + gen.set_state(gen_state) + + self.net = nn.Sequential(down, up) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.net[0].weight.dtype + up_hidden_states = self.net(hidden_states.to(dtype)) + + return up_hidden_states.to(orig_dtype) diff --git a/cosmos_predict1/diffusion/training/utils/peft/peft.py b/cosmos_predict1/diffusion/training/utils/peft/peft.py new file mode 100644 index 0000000000000000000000000000000000000000..2540514d325976b8f84501c02e0e1ac043c9349d --- /dev/null +++ b/cosmos_predict1/diffusion/training/utils/peft/peft.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cosmos_predict1.diffusion.training.utils.peft.lora_attn import build_attn_lora +from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType +from cosmos_predict1.utils import log +from cosmos_predict1.utils.misc import count_params + + +def get_all_lora_params(model): + """ + Get all LoRA weight parameters in the model + """ + lora_modules = [mod for name, mod in model.named_modules() if "lora.net.0" in name or "lora.net.1" in name] + lora_params = [(name, param) for mod in lora_modules for name, param in mod.named_parameters()] + log.info(f"Found {len(lora_params)} LoRA weight matrices") + return lora_params + + +def setup_lora_requires_grad(model): + """ + Freeze all model parameters except LoRA parameters. + """ + num_param = count_params(model, verbose=True) + log.critical(f"Model has {num_param * 1e-6:.2f}M parameters before freezing") + lora_params = get_all_lora_params(model) + num_lora_param = sum([p.numel() for _, p in lora_params]) + log.info(f"Total number of LoRA parameters: {num_lora_param * 1e-6:.2f}M") + if num_lora_param > 0: + log.info("Freezing all parameters") + model.requires_grad_(False) + log.info("Unfreezing LoRA parameters") + for name, param in lora_params: + # log.info(f"Unfreezing loRA : {name}") + param.requires_grad_(True) + num_param = count_params(model, verbose=True) + log.critical(f"Model has {num_param * 1e-6:.2f}M parameters after freezing") + return num_lora_param + + +def add_lora_layers(model, peft_control_config): + for i, block_name in enumerate(model.net.blocks): + block = model.net.blocks[block_name] + peft_control = peft_control_config.get(i, {}) + for j, subblock in enumerate(block.blocks): + block_type = subblock.block_type + peft_control_subblock = peft_control.get(block_type.upper(), {}) + customization_type = peft_control_subblock.get("customization_type", None) + if customization_type == CustomizationType.LORA: + if block_type.upper() in ["CA", "FA"]: + build_attn_lora(subblock.block.attn, peft_control_subblock) diff --git a/cosmos_predict1/diffusion/types.py b/cosmos_predict1/diffusion/types.py new file mode 100644 index 0000000000000000000000000000000000000000..0459d88d423794a97fb62ee3275c57c9d1c2007a --- /dev/null +++ b/cosmos_predict1/diffusion/types.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
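Putting the pieces together, a sketch of how these helpers are typically wired for LoRA fine-tuning. The DummyModel construction mirrors the unit test above; the optimizer, learning rate, and first_nblocks value are illustrative choices rather than repository defaults, and a CUDA device is assumed.

import copy

import torch
import torch.nn as nn

from cosmos_predict1.diffusion.config.base.net import FADITV2Config
from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser
from cosmos_predict1.diffusion.training.utils.peft.lora_config import get_fa_ca_qv_lora_config
from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad
from cosmos_predict1.utils.lazy_config import instantiate


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        dummy_net = copy.deepcopy(FADITV2Config)
        dummy_net.num_blocks = 2
        dummy_net.model_channels = 256
        dummy_net.num_heads = 8
        self.net = instantiate(dummy_net).cuda()


model = DummyModel()
config = get_fa_ca_qv_lora_config(first_nblocks=2, rank=8, scale=1)
peft_control = LayerControlConfigParser(config).parse()

add_lora_layers(model, peft_control)   # inject LoRA into FA/CA to_q and to_v
setup_lora_requires_grad(model)        # freeze everything except the LoRA weights
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)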
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class LabelImageCondition: + label: torch.Tensor + + def get_classifier_free_guidance_condition(self) -> LabelImageCondition: + return LabelImageCondition(torch.zeros_like(self.label)) + + +@dataclass +class DenoisePrediction: + x0: torch.Tensor # clean data prediction + eps: Optional[torch.Tensor] = None # noise prediction + logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty diff --git a/cosmos_predict1/diffusion/utils/customization/customization_manager.py b/cosmos_predict1/diffusion/utils/customization/customization_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc6d8446b7c443512e9fa97174a11d66253adb4 --- /dev/null +++ b/cosmos_predict1/diffusion/utils/customization/customization_manager.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from enum import Enum + + +class CustomizationType(Enum): + LORA = 1 + REPLACE = 2 + + @classmethod + def from_value(cls, value): + """Convert both int and str to the corresponding enum.""" + if isinstance(value, str): + value = value.lower() + if value == "lora": + return cls.LORA + elif value == "replace": + return cls.REPLACE + elif value == "": + return None + else: + raise ValueError("Customization type must be lora or replace") + raise TypeError("CustomizationType must be specified as a string.") diff --git a/cosmos_predict1/tokenizer/__init__.py b/cosmos_predict1/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/inference/__init__.py b/cosmos_predict1/tokenizer/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
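CustomizationType.from_value normalizes string inputs (case-insensitively) and maps the empty string to None; note that, as written, non-string values fall through to the TypeError even though the docstring mentions ints. A quick illustration:

from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType

assert CustomizationType.from_value("LoRA") is CustomizationType.LORA
assert CustomizationType.from_value("replace") is CustomizationType.REPLACE
assert CustomizationType.from_value("") is None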
diff --git a/cosmos_predict1/tokenizer/inference/image_cli.py b/cosmos_predict1/tokenizer/inference/image_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..5da9961c150a4ae3d0b81c81d430308c76c231a4 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/image_cli.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A CLI to run ImageTokenizer on plain images based on torch.jit. + +Usage: + python3 -m cosmos_predict1.tokenizer.inference.image_cli \ + --image_pattern 'path/to/input/folder/*.jpg' \ + --output_dir ./reconstructions \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit + + Optionally, you can run the model in pure PyTorch mode: + python3 -m cosmos_predict1.tokenizer.inference.image_cli \ + --image_pattern 'path/to/input/folder/*.jpg' \ + --mode torch \ + --tokenizer_type CI8x8 \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit +""" + +import os +import sys +from argparse import ArgumentParser, Namespace +from typing import Any + +import numpy as np +from loguru import logger as logging + +from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer +from cosmos_predict1.tokenizer.inference.utils import ( + get_filepaths, + get_output_filepath, + read_image, + resize_image, + write_image, +) +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + + +def _parse_args() -> tuple[Namespace, dict[str, Any]]: + parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.") + parser.add_argument( + "--image_pattern", + type=str, + default="path/to/images/*.jpg", + help="Glob pattern.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="JIT full Autoencoder model filepath.", + ) + parser.add_argument( + "--checkpoint_enc", + type=str, + default=None, + help="JIT Encoder model filepath.", + ) + parser.add_argument( + "--checkpoint_dec", + type=str, + default=None, + help="JIT Decoder model filepath.", + ) + parser.add_argument( + "--tokenizer_type", + type=str, + default=None, + choices=[ + "CI8x8-360p", + "CI16x16-360p", + "DI8x8-360p", + "DI16x16-360p", + ], + help="Specifies the tokenizer type.", + ) + parser.add_argument( + "--mode", + type=str, + choices=["torch", "jit"], + default="jit", + help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", + ) + parser.add_argument( + "--short_size", + type=int, + default=None, + help="The size to resample inputs. None, by default.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Sets the precision. 
Default bfloat16.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for invoking the model.", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") + parser.add_argument( + "--save_input", + action="store_true", + help="If on, the input image will be be outputed too.", + ) + args = parser.parse_args() + return args + + +logging.info("Initializes args ...") +args = _parse_args() +if args.mode == "torch" and args.tokenizer_type is None: + logging.error("'torch' backend requires the tokenizer_type to be specified.") + sys.exit(1) + + +def _run_eval() -> None: + """Invokes the evaluation pipeline.""" + + if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: + logging.warning("Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model.") + return + + if args.mode == "torch": + _type = args.tokenizer_type.replace("-", "_") + _config = TokenizerConfigs[_type].value + else: + _config = None + + logging.info( + f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." + ) + autoencoder = ImageTokenizer( + checkpoint=args.checkpoint, + checkpoint_enc=args.checkpoint_enc, + checkpoint_dec=args.checkpoint_dec, + tokenizer_config=_config, + device=args.device, + dtype=args.dtype, + ) + + filepaths = get_filepaths(args.image_pattern) + logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.") + + for filepath in filepaths: + logging.info(f"Reading image {filepath} ...") + image = read_image(filepath) + image = resize_image(image, short_size=args.short_size) + batch_image = np.expand_dims(image, axis=0) + + logging.info("Invoking the autoencoder model in ... ") + output_image = autoencoder(batch_image)[0] + + output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) + logging.info(f"Outputing {output_filepath} ...") + write_image(output_filepath, output_image) + + if args.save_input: + ext = os.path.splitext(output_filepath)[-1] + input_filepath = output_filepath.replace(ext, "_input" + ext) + write_image(input_filepath, image) + + +@logging.catch(reraise=True) +def main() -> None: + _run_eval() + + +if __name__ == "__main__": + main() diff --git a/cosmos_predict1/tokenizer/inference/image_lib.py b/cosmos_predict1/tokenizer/inference/image_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..9929d871a2d5862b0ced9e4e937f28033124ba62 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/image_lib.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
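A minimal programmatic equivalent of the image CLI above, assuming the JIT encoder/decoder checkpoints have already been downloaded; the checkpoint and image paths below are placeholders, not repository paths.

import numpy as np

from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer
from cosmos_predict1.tokenizer.inference.utils import read_image, write_image

autoencoder = ImageTokenizer(
    checkpoint_enc="./checkpoints/encoder.jit",  # placeholder path
    checkpoint_dec="./checkpoints/decoder.jit",  # placeholder path
    device="cuda",
    dtype="bfloat16",
)

image = read_image("input.jpg")                          # HxWxC, uint8, [0..255]; placeholder file
reconstruction = autoencoder(image[np.newaxis, ...])[0]  # forward expects and returns BxHxWxC
write_image("reconstruction.jpg", reconstruction)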
+ +"""A library for image tokenizers inference.""" + +from typing import Any + +import numpy as np +import torch + +from cosmos_predict1.tokenizer.inference.utils import ( + load_decoder_model, + load_encoder_model, + load_model, + numpy2tensor, + pad_image_batch, + tensor2numpy, + unpad_image_batch, +) + + +class ImageTokenizer(torch.nn.Module): + def __init__( + self, + checkpoint: str = None, + checkpoint_enc: str = None, + checkpoint_dec: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", + dtype: str = "bfloat16", + ) -> None: + super().__init__() + self._device = device + self._dtype = getattr(torch, dtype) + self._full_model = ( + load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None + ) + self._enc_model = ( + load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) + if checkpoint_enc is not None + else None + ) + self._dec_model = ( + load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) + if checkpoint_dec is not None + else None + ) + + @torch.no_grad() + def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Reconstrcuts a batch of image tensors after embedding into a latent. + + Args: + input_tensor: The input image Bx3xHxW layout, range [-1..1]. + Returns: + The reconstructed tensor, layout Bx3xHxW, range [-1..1]. + """ + if self._full_model is not None: + output_tensor = self._full_model(input_tensor) + output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor + else: + output_latent = self.encode(input_tensor)[0] + output_tensor = self.decode(output_latent) + return output_tensor + + @torch.no_grad() + def decode(self, input_latent: torch.Tensor) -> torch.Tensor: + """Decodes an image from a provided latent embedding. + + Args: + input_latent: The continuous latent Bx16xhxw for CI, + or the discrete indices Bxhxw for DI. + Returns: + The output tensor in Bx3xHxW, range [-1..1]. + """ + return self._dec_model(input_latent) + + @torch.no_grad() + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes an image into a latent embedding or code. + + Args: + input_tensor: The input tensor Bx3xHxW layout, range [-1..1]. + Returns: + For continuous image (CI) tokenizer, the tuple contains: + - The latent embedding, Bx16x(h)x(w), where the compression + rate is (H/h x W/w), and channel dimension of 16. + For discrete image (DI) tokenizer, the tuple contains: + - The indices, Bx(h)x(w), from a codebook of size 64K, which + corresponds to FSQ levels of (8,8,8,5,5,5). + - The discrete code, Bx6x(h)x(w), where the compression rate is + again (H/h x W/w), and channel dimension of 6. + """ + output_latent = self._enc_model(input_tensor) + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + @torch.no_grad() + def forward(self, image: np.ndarray) -> np.ndarray: + """Reconstructs an image using a pre-trained tokenizer. + + Args: + image: The input image BxHxWxC layout, range [0..255]. + Returns: + The reconstructed image in range [0..255], layout BxHxWxC. 
+ """ + padded_input_image, crop_region = pad_image_batch(image) + input_tensor = numpy2tensor(padded_input_image, dtype=self._dtype, device=self._device) + output_tensor = self.autoencode(input_tensor) + padded_output_image = tensor2numpy(output_tensor) + return unpad_image_batch(padded_output_image, crop_region) diff --git a/cosmos_predict1/tokenizer/inference/utils.py b/cosmos_predict1/tokenizer/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53feab8e282207a4810a32ce0f2d7ceb29273623 --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/utils.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for the inference libraries.""" + +import os +from glob import glob +from typing import Any + +import mediapy as media +import numpy as np +import torch + +from cosmos_predict1.tokenizer.networks import TokenizerModels + +_DTYPE, _DEVICE = torch.bfloat16, "cuda" +_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) +_SPATIAL_ALIGN = 16 +_TEMPORAL_ALIGN = 8 + + +def load_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + full_model.load_state_dict(ckpts.state_dict(), strict=True) + return full_model.eval().to(device) + + +def load_encoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + encoder_model = full_model.encoder_jit() + encoder_model.load_state_dict(ckpts.state_dict(), strict=True) + return encoder_model.eval().to(device) + + +def load_decoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. 
+ """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + decoder_model = full_model.decoder_jit() + decoder_model.load_state_dict(ckpts.state_dict(), strict=True) + return decoder_model.eval().to(device) + + +def _load_pytorch_model( + jit_filepath: str = None, tokenizer_config: str = None, device: str = "cuda" +) -> torch.nn.Module: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + tokenizer_name = tokenizer_config["name"] + model = TokenizerModels[tokenizer_name].value(**tokenizer_config) + ckpts = torch.jit.load(jit_filepath, map_location=device) + return model, ckpts + + +def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: + """Loads a torch.jit.ScriptModule from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + model = torch.jit.load(jit_filepath, map_location=device) + return model.eval().to(device) + + +def save_jit_model( + model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None, + jit_filepath: str = None, +) -> None: + """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file. + + Args: + model: JIT compiled model loaded onto `config.checkpoint.jit.device`. + jit_filepath: The filepath to the JIT-compiled model. + """ + torch.jit.save(model, jit_filepath) + + +def get_filepaths(input_pattern) -> list[str]: + """Returns a list of filepaths from a pattern.""" + filepaths = sorted(glob(str(input_pattern))) + return list(set(filepaths)) + + +def get_output_filepath(filepath: str, output_dir: str = None) -> str: + """Returns the output filepath for the given input filepath.""" + output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions" + output_filepath = f"{output_dir}/{os.path.basename(filepath)}" + os.makedirs(output_dir, exist_ok=True) + return output_filepath + + +def read_image(filepath: str) -> np.ndarray: + """Reads an image from a filepath. + + Args: + filepath: The filepath to the image. + + Returns: + The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype. + """ + image = media.read_image(filepath) + # convert the grey scale image to RGB + # since our tokenizers always assume 3-channel RGB image + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + # convert RGBA to RGB + if image.shape[-1] == 4: + image = image[..., :3] + return image + + +def read_video(filepath: str) -> np.ndarray: + """Reads a video from a filepath. + + Args: + filepath: The filepath to the video. + Returns: + The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype. + """ + video = media.read_video(filepath) + # convert the grey scale frame to RGB + # since our tokenizers always assume 3-channel video + if video.ndim == 3: + video = np.stack([video] * 3, axis=-1) + # convert RGBA to RGB + if video.shape[-1] == 4: + video = video[..., :3] + return video + + +def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes an image to have the short side of `short_size`. + + Args: + image: The image to resize, layout HxWxC, of any range. + short_size: The size of the short side. 
+ Returns: + The resized image. + """ + if short_size is None: + return image + height, width = image.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_image(image, shape=(height_new, width_new)) + + +def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes a video to have the short side of `short_size`. + + Args: + video: The video to resize, layout TxHxWxC, of any range. + short_size: The size of the short side. + Returns: + The resized video. + """ + if short_size is None: + return video + height, width = video.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_video(video, shape=(height_new, width_new)) + + +def write_image(filepath: str, image: np.ndarray): + """Writes an image to a filepath.""" + return media.write_image(filepath, image) + + +def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None: + """Writes a video to a filepath.""" + return media.write_video(filepath, video, fps=fps) + + +def numpy2tensor( + input_image: np.ndarray, + dtype: torch.dtype = _DTYPE, + device: str = _DEVICE, + range_min: int = -1, +) -> torch.Tensor: + """Converts image(dtype=np.uint8) to `dtype` in range [0..255]. + + Args: + input_image: A batch of images in range [0..255], BxHxWx3 layout. + Returns: + A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype. + """ + ndim = input_image.ndim + indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1] + image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F + if range_min == -1: + image = 2.0 * image - 1.0 + return torch.from_numpy(image).to(dtype).to(device) + + +def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray: + """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255]. + + Args: + input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]. + Returns: + A numpy image of layout BxHxWx3, range [0..255], uint8 dtype. + """ + if range_min == -1: + input_tensor = (input_tensor.float() + 1.0) / 2.0 + ndim = input_tensor.ndim + output_image = input_tensor.clamp(0, 1).cpu().numpy() + output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,)) + return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8) + + +def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]: + """Pads a batch of images to be divisible by `spatial_align`. + + Args: + batch: The batch of images to pad, layout BxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. 
+ """ + height, width = batch.shape[1:3] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + crop_region = [ + height_to_pad >> 1, + width_to_pad >> 1, + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + return batch, crop_region + + +def pad_video_batch( + batch: np.ndarray, + temporal_align: int = _TEMPORAL_ALIGN, + spatial_align: int = _SPATIAL_ALIGN, +) -> tuple[np.ndarray, list[int]]: + """Pads a batch of videos to be divisible by `temporal_align` or `spatial_align`. + + Zero pad spatially. Reflection pad temporally to handle causality better. + Args: + batch: The batch of videos to pad., layout BxFxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. + """ + num_frames, height, width = batch.shape[-4:-1] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + align = temporal_align + frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0 + + crop_region = [ + frames_to_pad >> 1, + height_to_pad >> 1, + width_to_pad >> 1, + num_frames + (frames_to_pad >> 1), + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + batch = np.pad( + batch, + ( + (0, 0), + (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)), + (0, 0), + (0, 0), + (0, 0), + ), + mode="edge", + ) + return batch, crop_region + + +def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads video with `crop_region`. + + Args: + batch: A batch of numpy videos, layout BxFxHxWxC. + crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy video, layout BxFxHxWxC. + """ + assert len(crop_region) == 6, "crop_region should be len of 6." + f1, y1, x1, f2, y2, x2 = crop_region + return batch[..., f1:f2, y1:y2, x1:x2, :] + + +def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads image with `crop_region`. + + Args: + batch: A batch of numpy images, layout BxHxWxC. + crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy image, layout BxHxWxC. + """ + assert len(crop_region) == 4, "crop_region should be len of 4." + y1, x1, y2, x2 = crop_region + return batch[..., y1:y2, x1:x2, :] diff --git a/cosmos_predict1/tokenizer/inference/video_cli.py b/cosmos_predict1/tokenizer/inference/video_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..51e17903c6b6b825c78c2e40760ea88aff51141f --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/video_cli.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A CLI to run CausalVideoTokenizer on plain videos based on torch.jit. + +Usage: + python3 -m cosmos_predict1.tokenizer.inference.video_cli \ + --video_pattern 'path/to/video/samples/*.mp4' \ + --output_dir ./reconstructions \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit + + Optionally, you can run the model in pure PyTorch mode: + python3 -m cosmos_predict1.tokenizer.inference.video_cli \ + --video_pattern 'path/to/video/samples/*.mp4' \ + --mode=torch \ + --tokenizer_type=CV \ + --temporal_compression=4 \ + --spatial_compression=8 \ + --checkpoint_enc ./checkpoints//encoder.jit \ + --checkpoint_dec ./checkpoints//decoder.jit +""" + +import os +import sys +from argparse import ArgumentParser, Namespace +from typing import Any + +import numpy as np +from loguru import logger as logging + +from cosmos_predict1.tokenizer.inference.utils import ( + get_filepaths, + get_output_filepath, + read_video, + resize_video, + write_video, +) +from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + + +def _parse_args() -> tuple[Namespace, dict[str, Any]]: + parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.") + parser.add_argument( + "--video_pattern", + type=str, + default="path/to/videos/*.mp4", + help="Glob pattern.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="JIT full Autoencoder model filepath.", + ) + parser.add_argument( + "--checkpoint_enc", + type=str, + default=None, + help="JIT Encoder model filepath.", + ) + parser.add_argument( + "--checkpoint_dec", + type=str, + default=None, + help="JIT Decoder model filepath.", + ) + parser.add_argument( + "--tokenizer_type", + type=str, + default=None, + choices=[ + "CV8x8x8-720p", + "DV8x16x16-720p", + "CV4x8x8-360p", + "DV4x8x8-360p", + ], + help="Specifies the tokenizer type.", + ) + parser.add_argument( + "--mode", + type=str, + choices=["torch", "jit"], + default="jit", + help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", + ) + parser.add_argument( + "--short_size", + type=int, + default=None, + help="The size to resample inputs. 
None, by default.", + ) + parser.add_argument( + "--temporal_window", + type=int, + default=17, + help="The temporal window to operate at a time.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Sets the precision, default bfloat16.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for invoking the model.", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") + parser.add_argument( + "--output_fps", + type=float, + default=24.0, + help="Output frames-per-second (FPS).", + ) + parser.add_argument( + "--save_input", + action="store_true", + help="If on, the input video will be be outputted too.", + ) + args = parser.parse_args() + return args + + +logging.info("Initializes args ...") +args = _parse_args() +if args.mode == "torch" and args.tokenizer_type is None: + logging.error("`torch` backend requires `--tokenizer_type` to be specified.") + sys.exit(1) + + +def _run_eval() -> None: + """Invokes JIT-compiled CausalVideoTokenizer on an input video.""" + + if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: + logging.warning("Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model.") + return + + if args.mode == "torch": + _type = args.tokenizer_type.replace("-", "_") + _config = TokenizerConfigs[_type].value + else: + _config = None + + logging.info( + f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." + ) + autoencoder = CausalVideoTokenizer( + checkpoint=args.checkpoint, + checkpoint_enc=args.checkpoint_enc, + checkpoint_dec=args.checkpoint_dec, + tokenizer_config=_config, + device=args.device, + dtype=args.dtype, + ) + + logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...") + filepaths = get_filepaths(args.video_pattern) + logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.") + + for filepath in filepaths: + logging.info(f"Reading video {filepath} ...") + video = read_video(filepath) + video = resize_video(video, short_size=args.short_size) + + logging.info("Invoking the autoencoder model in ... ") + batch_video = video[np.newaxis, ...] + output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0] + logging.info("Constructing output filepath ...") + output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) + logging.info(f"Outputing {output_filepath} ...") + write_video(output_filepath, output_video, fps=args.output_fps) + if args.save_input: + ext = os.path.splitext(output_filepath)[-1] + input_filepath = output_filepath.replace(ext, "_input" + ext) + write_video(input_filepath, video, fps=args.output_fps) + + +@logging.catch(reraise=True) +def main() -> None: + _run_eval() + + +if __name__ == "__main__": + main() diff --git a/cosmos_predict1/tokenizer/inference/video_lib.py b/cosmos_predict1/tokenizer/inference/video_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe8d2ec1b52e01c4823a1f660133a529762744d --- /dev/null +++ b/cosmos_predict1/tokenizer/inference/video_lib.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
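The --temporal_window flag controls how CausalVideoTokenizer.forward (defined further below) slides over the clip. A sketch of the resulting frame partition for an assumed 121-frame input with the default window of 17:

num_frames, temporal_window = 121, 17  # example clip length; 17 is the CLI default
windows = [
    (idx * temporal_window, min((idx + 1) * temporal_window, num_frames))
    for idx in range((num_frames - 1) // temporal_window + 1)
]
print(len(windows), windows[0], windows[-1])  # 8 (0, 17) (119, 121)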
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A library for Causal Video Tokenizer inference.""" + +from typing import Any + +import numpy as np +import torch +from tqdm import tqdm + +from cosmos_predict1.tokenizer.inference.utils import ( + load_decoder_model, + load_encoder_model, + load_model, + numpy2tensor, + pad_video_batch, + tensor2numpy, + unpad_video_batch, +) + + +class CausalVideoTokenizer(torch.nn.Module): + def __init__( + self, + checkpoint: str = None, + checkpoint_enc: str = None, + checkpoint_dec: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", + dtype: str = "bfloat16", + ) -> None: + super().__init__() + self._device = device + self._dtype = getattr(torch, dtype) + self._full_model = ( + load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None + ) + self._enc_model = ( + load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) + if checkpoint_enc is not None + else None + ) + self._dec_model = ( + load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) + if checkpoint_dec is not None + else None + ) + + @torch.no_grad() + def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Reconstrcuts a batch of video tensors after embedding into a latent. + + Args: + video: The input video Bx3xTxHxW layout, range [-1..1]. + Returns: + The reconstructed video, layout Bx3xTxHxW, range [-1..1]. + """ + if self._full_model is not None: + output_tensor = self._full_model(input_tensor) + output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor + else: + output_latent = self.encode(input_tensor)[0] + output_tensor = self.decode(output_latent) + return output_tensor + + @torch.no_grad() + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes a numpy video into a CausalVideo latent or code. + + Args: + input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1]. + Returns: + For causal continuous video (CV) tokenizer, the tuple contains: + - The latent embedding, Bx16x(t)x(h)x(w), where the compression + rate is (T/t x H/h x W/w), and channel dimension of 16. + For causal discrete video (DV) tokenizer, the tuple contains: + 1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which + is formed by FSQ levels of (8,8,8,5,5,5). + 2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate + is again (T/t x H/h x W/w), and channel dimension of 6. + """ + assert input_tensor.ndim == 5, "input video should be of 5D." + + output_latent = self._enc_model(input_tensor) + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + @torch.no_grad() + def decode(self, input_latent: torch.Tensor) -> torch.Tensor: + """Encodes a numpy video into a CausalVideo latent. + + Args: + input_latent: The continuous latent Bx16xtxhxw for CV, + or the discrete indices Bxtxhxw for DV. + Returns: + The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1]. + """ + assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete." 
+ return self._dec_model(input_latent) + + def forward( + self, + video: np.ndarray, + temporal_window: int = 17, + ) -> np.ndarray: + """Reconstructs video using a pre-trained CausalTokenizer autoencoder. + Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer + in a sliding manner with a `temporal_window` size. + + Args: + video: The input video BxTxHxWx3 layout, range [0..255]. + temporal_window: The length of the temporal window to process, default=25. + Returns: + The reconstructed video in range [0..255], layout BxTxHxWx3. + """ + assert video.ndim == 5, "input video should be of 5D." + num_frames = video.shape[1] # can be of any length. + output_video_list = [] + for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)): + # Input video for the current window. + start, end = idx * temporal_window, (idx + 1) * temporal_window + input_video = video[:, start:end, ...] + + # Spatio-temporally pad input_video so it's evenly divisible. + padded_input_video, crop_region = pad_video_batch(input_video) + input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device) + output_tensor = self.autoencode(input_tensor) + padded_output_video = tensor2numpy(output_tensor) + output_video = unpad_video_batch(padded_output_video, crop_region) + + output_video_list.append(output_video) + return np.concatenate(output_video_list, axis=1) diff --git a/cosmos_predict1/tokenizer/modules/__init__.py b/cosmos_predict1/tokenizer/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a15e15e4bfa8b0a33140ce78830be28defb91238 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/__init__.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
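A minimal programmatic use of CausalVideoTokenizer mirroring video_cli above; the checkpoint and video paths are placeholders, and a CUDA device is assumed.

import numpy as np

from cosmos_predict1.tokenizer.inference.utils import read_video, write_video
from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer

autoencoder = CausalVideoTokenizer(
    checkpoint_enc="./checkpoints/encoder.jit",  # placeholder path
    checkpoint_dec="./checkpoints/decoder.jit",  # placeholder path
    device="cuda",
    dtype="bfloat16",
)

video = read_video("input.mp4")                                 # TxHxWxC, uint8, [0..255]; placeholder file
batch_video = video[np.newaxis, ...]                            # BxTxHxWxC
output_video = autoencoder(batch_video, temporal_window=17)[0]  # sliding-window reconstruction
write_video("reconstruction.mp4", output_video, fps=24)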
+ +from enum import Enum + +from cosmos_predict1.tokenizer.modules.distributions import GaussianDistribution, IdentityDistribution +from cosmos_predict1.tokenizer.modules.layers2d import Decoder, Encoder +from cosmos_predict1.tokenizer.modules.layers3d import DecoderBase, DecoderFactorized, EncoderBase, EncoderFactorized +from cosmos_predict1.tokenizer.modules.quantizers import FSQuantizer, LFQuantizer, ResidualFSQuantizer, VectorQuantizer + + +class EncoderType(Enum): + Default = Encoder + + +class DecoderType(Enum): + Default = Decoder + + +class Encoder3DType(Enum): + BASE = EncoderBase + FACTORIZED = EncoderFactorized + + +class Decoder3DType(Enum): + BASE = DecoderBase + FACTORIZED = DecoderFactorized + + +class ContinuousFormulation(Enum): + VAE = GaussianDistribution + AE = IdentityDistribution + + +class DiscreteQuantizer(Enum): + VQ = VectorQuantizer + LFQ = LFQuantizer + FSQ = FSQuantizer + RESFSQ = ResidualFSQuantizer diff --git a/cosmos_predict1/tokenizer/modules/distributions.py b/cosmos_predict1/tokenizer/modules/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..2347f7453611d9fea87d0f530bd8e54f02c3f39e --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/distributions.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The distribution modes to use for continuous image tokenizers.""" + +import torch + + +class IdentityDistribution(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, parameters): + return parameters, (torch.tensor([0.0]), torch.tensor([0.0])) + + +class GaussianDistribution(torch.nn.Module): + def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0): + super().__init__() + self.min_logvar = min_logvar + self.max_logvar = max_logvar + + def sample(self, mean, logvar): + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + def forward(self, parameters): + mean, logvar = torch.chunk(parameters, 2, dim=1) + logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar) + return self.sample(mean, logvar), (mean, logvar) diff --git a/cosmos_predict1/tokenizer/modules/layers2d.py b/cosmos_predict1/tokenizer/modules/layers2d.py new file mode 100644 index 0000000000000000000000000000000000000000..5770bcf62f45468568fd3c99e22d9d4c9582d38f --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers2d.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
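GaussianDistribution expects the encoder to emit mean and log-variance stacked along the channel dimension and returns a reparameterized sample together with the (mean, logvar) pair. A quick shape check with assumed tensor sizes:

import torch

from cosmos_predict1.tokenizer.modules.distributions import GaussianDistribution

dist = GaussianDistribution()
parameters = torch.randn(1, 32, 30, 40)  # B x 2C x H x W with C = 16
sample, (mean, logvar) = dist(parameters)
assert sample.shape == (1, 16, 30, 40)
assert mean.shape == logvar.shape == (1, 16, 30, 40)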
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The model definition for Continuous 2D layers + +Adapted from: https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py + +[Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors] +https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/LICENSE +""" + +import math + +import numpy as np + +# pytorch_diffusion + derived encoder decoder +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger as logging + +from cosmos_predict1.tokenizer.modules.patching import Patcher, UnPatcher +from cosmos_predict1.tokenizer.modules.utils import Normalize, nonlinearity + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) + return self.conv(x) + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.conv(x) + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = Normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nin_shortcut = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** 
(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size + + # calculate the number of downsample operations + self.num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_downsamples <= self.num_resolutions + ), f"we can only downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level < self.num_downsamples: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level < self.num_downsamples: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: int, + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size + + # calculate the number of upsample operations + self.num_upsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_upsamples <= self.num_resolutions, f"we can only upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level >= (self.num_resolutions - self.num_upsamples): + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level >= (self.num_resolutions - self.num_upsamples): + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher(h) + return h diff --git a/cosmos_predict1/tokenizer/modules/layers2d_test.py b/cosmos_predict1/tokenizer/modules/layers2d_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1fac5a0015b842febd68f3ee9e153741fec9febf --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers2d_test.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
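As a quick check of the compression arithmetic in the 2D `Encoder`/`Decoder` above: the number of stride-2 stages is the log2 of the target spatial compression minus the log2 of the patch size, because the (un)patcher already contributes a factor of `patch_size`. A small sketch with assumed example values (the real configurations live elsewhere in the tokenizer package):

import math

# Assumed example values, not a shipped config.
spatial_compression, patch_size = 16, 4
num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
print(num_downsamples)  # 2 -> 4x from patching plus two stride-2 stages gives 16x overall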
+ +"""The test for model definition of 2D layers + +PYTHONPATH=$PWD pytest -v cosmos_predict1/tokenizer/modules/layers2d_test.py +""" +import os + +import numpy as np +import pytest +import torch +from torchvision.transforms import CenterCrop + +from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer +from cosmos_predict1.tokenizer.inference.utils import read_image +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + +# test configs +TEST_CONFIGS = [ + ("CI8x8-360p", "checkpoints/Cosmos-Tokenize1-CI8x8-360p"), + ("CI16x16-360p", "checkpoints/Cosmos-Tokenize1-CI16x16-360p"), + ("DI8x8-360p", "checkpoints/Cosmos-Tokenize1-DI8x8-360p"), + ("DI16x16-360p", "checkpoints/Cosmos-Tokenize1-DI16x16-360p"), +] + + +@pytest.fixture(scope="module") +def image_tensor(): + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "test_data", "image.png") + print(f"image_path: {image_path}") + image = read_image(image_path) + + assert image.shape[0] >= 512, "Image height should be at least 512 pixels" + assert image.shape[1] >= 512, "Image width should be at least 512 pixels" + assert image.shape[2] == 3, "Image should have 3 channels" + + input_tensor = CenterCrop(512)( + torch.from_numpy(image[np.newaxis, ...]).to("cuda").to(torch.bfloat16).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0 + ) + return input_tensor + + +@pytest.mark.parametrize("config", TEST_CONFIGS) +def test_tokenizer(config, image_tensor): + name, model_id = config + continuous = name.startswith(("C", "c")) + [ + spatial_compression, + ] = list(map(int, name[2:].split("x")[:1])) + print(f"\nTesting tokenizer: {model_id}") + print(f"spatial_compression={spatial_compression}") + + _config = TokenizerConfigs[name.replace("-", "_")].value + autoencoder = ImageTokenizer( + checkpoint_enc=f"{model_id}/encoder.jit", + checkpoint_dec=f"{model_id}/decoder.jit", + tokenizer_config=_config, + device="cuda", + dtype="bfloat16", + ) + + try: + # Test shape check + reconstructed_tensor = auto_shape_check(image_tensor, autoencoder, spatial_compression, continuous) + finally: + # Cleanup + del autoencoder + del reconstructed_tensor + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def auto_shape_check(input_tensor, autoencoder, spatial_compression, continuous): + if continuous: + (latent,) = autoencoder.encode(input_tensor) + torch.testing.assert_close(latent.shape, (1, 16, 512 // spatial_compression, 512 // spatial_compression)) + reconstructed_tensor = autoencoder.decode(latent) + else: + (indices, codes) = autoencoder.encode(input_tensor) + torch.testing.assert_close(indices.shape, (1, 512 // spatial_compression, 512 // spatial_compression)) + torch.testing.assert_close(codes.shape, (1, 6, 512 // spatial_compression, 512 // spatial_compression)) + reconstructed_tensor = autoencoder.decode(indices) + + torch.testing.assert_close(reconstructed_tensor.shape, input_tensor.shape) + return reconstructed_tensor diff --git a/cosmos_predict1/tokenizer/modules/layers3d.py b/cosmos_predict1/tokenizer/modules/layers3d.py new file mode 100644 index 0000000000000000000000000000000000000000..4d12c37240d6f3c5e1d38870fa4b9099c5167b2e --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers3d.py @@ -0,0 +1,949 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger as logging + +from cosmos_predict1.tokenizer.modules.patching import Patcher, Patcher3D, UnPatcher, UnPatcher3D +from cosmos_predict1.tokenizer.modules.utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) + +_LEGACY_NUM_GROUPS = 32 + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs, + ) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalUpsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = self.conv(x) + return x[..., int(time_factor - 1) :, :, :] + + +class CausalDownsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d( + in_channels, + in_channels, + kernel_size=3, + stride=2, + time_stride=2, + padding=0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", 
value=0) + x = replication_pad(x) + x = self.conv(x) + return x + + +class CausalHybridUpsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_up: bool = True, + temporal_up: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0) + if temporal_up + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1) + if spatial_up + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_up or temporal_up + else nn.Identity() + ) + self.spatial_up = spatial_up + self.temporal_up = temporal_up + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_down: bool = True, + temporal_down: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0) + if spatial_down + else nn.Identity() + ) + self.conv2 = ( + CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0) + if temporal_down + else nn.Identity() + ) + self.conv3 = ( + CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0) + if spatial_down or temporal_down + else nn.Identity() + ) + + self.spatial_down = spatial_down + self.temporal_down = temporal_down + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. 
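# Note: conv3 below is a 1x1x1 projection; it only mixes channels after the
# strided-conv and average-pool branches have been summed, so all resolution
# changes happen in the branches above.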
+ x = self.conv3(x) + return x + + +class CausalResnetBlock3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=num_groups) + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = 
torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderBase(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size + + # downsampling + self.conv_in = CausalConv3d(in_channels, channels, kernel_size=3, stride=1, padding=1) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. 
+ num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = CausalDownsample3d(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + def patcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.patcher(x) + x = batch2time(x, batch_size) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + else: + # temporal downsample (last level) + time_factor = 1 + 1 * (hs[-1].shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + hs[-1] = replication_pad(hs[-1]) + hs.append( + F.avg_pool3d( + hs[-1], + kernel_size=[time_factor, 1, 1], + stride=[2, 1, 1], + ) + ) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderBase(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. + num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = CausalUpsample3d(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.unpatcher(x) + x = batch2time(x, batch_size) + + return x + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + else: + # temporal upsample (last level) + time_factor = 1.0 + 1.0 * (h.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + h = h.repeat_interleave(int(time_factor), dim=2) + h = h[..., int(time_factor - 1) :, :, :] + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d( + in_channels, + channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, + spatial_down=spatial_down, + temporal_down=temporal_down, + ) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d( + z_channels, + z_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + 
attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed + # in the encoder should correspond to the layer index in + # reverse order where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. 
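# Worked example (illustrative values, not a shipped config): with
# num_resolutions=4, num_spatial_ups=2 and num_temporal_ups=2, i_level takes
# 3, 2, 1 here, so i_level_reverse takes 0, 1, 2. With legacy_mode=False the
# hybrid upsampling is active at i_level_reverse 1 and 2; with legacy_mode=True
# it is active at 0 and 1 instead, which is the placement mismatch the note
# above refers to.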
+ i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h diff --git a/cosmos_predict1/tokenizer/modules/layers3d_test.py b/cosmos_predict1/tokenizer/modules/layers3d_test.py new file mode 100644 index 0000000000000000000000000000000000000000..711e279afa811123485196af2355dd6bc6b27a36 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/layers3d_test.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
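The causal video tests that follow rely on the tokenizer's latent-shape arithmetic: a clip of T frames maps to (T - 1) // temporal_compression + 1 latent frames (the first frame is encoded on its own), and the spatial dimensions divide by the spatial compression. A small sketch using the 17-frame, 512x512 crop from the test below; the compression factors are example values taken from the CV8x8x8 naming convention:

T, H, W = 17, 512, 512
temporal_compression, spatial_compression = 8, 8
t_latent = (T - 1) // temporal_compression + 1   # 17 frames -> 3 latent frames
h_latent, w_latent = H // spatial_compression, W // spatial_compression
print(t_latent, h_latent, w_latent)  # 3 64 64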
+ +"""The test for model definition of 3D layers + +PYTHONPATH=$PWD pytest -v cosmos_predict1/tokenizer/modules/layers3d_test.py +""" +import os + +import numpy as np +import pytest +import torch +from torchvision.transforms import CenterCrop + +from cosmos_predict1.tokenizer.inference.utils import read_video +from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer +from cosmos_predict1.tokenizer.networks import TokenizerConfigs + +# test configs +TEST_CONFIGS = [ + ("CV8x8x8-720p", "checkpoints/Cosmos-Tokenize1-CV8x8x8-720p"), + ("DV8x16x16-720p", "checkpoints/Cosmos-Tokenize1-DV8x16x16-720p"), + ("CV4x8x8-360p", "checkpoints/Cosmos-Tokenize1-CV4x8x8-360p"), + ("DV4x8x8-360p", "checkpoints/Cosmos-Tokenize1-DV4x8x8-360p"), +] + + +@pytest.fixture(scope="module") +def video_tensor(): + video_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "test_data", "video.mp4") + print(f"video_path: {video_path}") + video = read_video(video_path) + + assert video.shape[0] >= 17, "Video length should be at least 17 frames" + assert video.shape[1] >= 512, "Video height should be at least 512 pixels" + assert video.shape[2] >= 512, "Video width should be at least 512 pixels" + assert video.shape[3] == 3, "Video should have 3 channels" + + input_tensor = CenterCrop(512)( + torch.from_numpy(video[np.newaxis, ...])[:, :17].to("cuda").to(torch.bfloat16).permute(0, 4, 1, 2, 3) + / 255.0 + * 2.0 + - 1.0 + ) + return input_tensor + + +@pytest.mark.parametrize("config", TEST_CONFIGS) +def test_tokenizer(config, video_tensor): + name, model_id = config + continuous = name.startswith(("C", "c")) + temporal_compression, spatial_compression = list(map(int, name[2:].split("x")[:2])) + print(f"\nTesting tokenizer: {model_id}") + print(f"temporal_compression={temporal_compression}") + print(f"spatial_compression={spatial_compression}") + print(f"checkpoint_enc=checkpoints/{os.path.basename(model_id)}/encoder.jit") + print(f"checkpoint_dec=checkpoints/{os.path.basename(model_id)}/decoder.jit") + + _config = TokenizerConfigs[name.replace("-", "_")].value + autoencoder = CausalVideoTokenizer( + checkpoint_enc=f"checkpoints/{os.path.basename(model_id)}/encoder.jit", + checkpoint_dec=f"checkpoints/{os.path.basename(model_id)}/decoder.jit", + tokenizer_config=_config, + device="cuda", + dtype="bfloat16", + ) + + try: + # Test shape check + reconstructed_tensor = auto_shape_check( + video_tensor, autoencoder, temporal_compression, spatial_compression, continuous + ) + finally: + # Cleanup + del autoencoder + del reconstructed_tensor + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def auto_shape_check(input_tensor, autoencoder, temporal_compression, spatial_compression, continuous): + if continuous: + (latent,) = autoencoder.encode(input_tensor) + torch.testing.assert_close( + latent.shape, + (1, 16, (17 - 1) // temporal_compression + 1, 512 // spatial_compression, 512 // spatial_compression), + ) + reconstructed_tensor = autoencoder.decode(latent) + else: + (indices, codes) = autoencoder.encode(input_tensor) + torch.testing.assert_close( + indices.shape, + (1, (17 - 1) // temporal_compression + 1, 512 // spatial_compression, 512 // spatial_compression), + ) + torch.testing.assert_close( + codes.shape, + (1, 6, (17 - 1) // temporal_compression + 1, 512 // spatial_compression, 512 // spatial_compression), + ) + reconstructed_tensor = autoencoder.decode(indices) + + torch.testing.assert_close(reconstructed_tensor.shape, input_tensor.shape) + return reconstructed_tensor diff 
--git a/cosmos_predict1/tokenizer/modules/patching.py b/cosmos_predict1/tokenizer/modules/patching.py new file mode 100644 index 0000000000000000000000000000000000000000..028df019cdf9bf1126682e144f09c107da834908 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/patching.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The patcher and unpatcher implementation for 2D and 3D data. + +The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions. +One on the rows and one on the columns. +For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2. +We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component. +For H component, we can use a 1D convolution with kernel [1, -1] and stride 2. +Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all + as we need to support downsampling for more than 2x. +For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be. + [3, 256, 256] -> [12, 128, 128] -> [48, 64, 64] +""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = True + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. + + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable. 
+ """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange( + x, + "b c (h p1) (w p2) -> b (c p1 p2) h w", + p1=self.patch_size, + p2=self.patch_size, + ).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", + patch_size * torch.ones([1], dtype=torch.int32), + persistent=_PERSISTENT, + ) + + def _dwt(self, x, wavelet, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. 
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, "haar", rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. 
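# conv_transpose2d with the analysis filters undoes the two 1D passes of _dwt
# in reverse order (height first, then width); the * 2 rescale further below
# inverts the / 2 applied during the forward transform.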
+ yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.patch_size, + p2=self.patch_size, + ) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. + xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] 
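# The slice above drops the (patch_size - 1) leading frames that Patcher3D
# duplicated from the first frame, restoring the original causal temporal length.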
+ return x diff --git a/cosmos_predict1/tokenizer/modules/quantizers.py b/cosmos_predict1/tokenizer/modules/quantizers.py new file mode 100644 index 0000000000000000000000000000000000000000..70a1e0c95c9e1143eb278d1c4a554073939af437 --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/quantizers.py @@ -0,0 +1,507 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantizers for discrete image and video tokenization.""" + +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import reduce +from loguru import logger as logging + +from cosmos_predict1.tokenizer.modules.utils import default, entropy, pack_one, rearrange, round_ste, unpack_one + +_PERSISTENT = True + + +class ResidualFSQuantizer(nn.Module): + """Residual Finite Scalar Quantization + + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, levels: list[int], num_quantizers: int, **ignore_kwargs): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.float32) + self.layers = nn.ModuleList([FSQuantizer(levels=levels) for _ in range(num_quantizers)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + indices_stack = [] + residual = x + quantized_out = 0 + loss_out = 0 + for i, layer in enumerate(self.layers): + quant_indices, z, loss = layer(residual) + indices_stack.append(quant_indices) + residual = residual - z.detach() + quantized_out = quantized_out + z + loss_out = loss_out + loss + self.residual = residual + indices = torch.stack(indices_stack, dim=1) + return indices, quantized_out.to(self.dtype), loss_out.to(self.dtype) + + def indices_to_codes(self, indices_stack: torch.Tensor) -> torch.Tensor: + quantized_out = 0 + for layer, indices in zip(self.layers, indices_stack.transpose(0, 1)): + quantized_out += layer.indices_to_codes(indices) + return quantized_out + + +class FSQuantizer(nn.Module): + """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 + + Code adapted from Jax version in Appendix A.1. 
+ + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/finite_scalar_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + levels: list[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + **ignore_kwargs, + ): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.bfloat16) + self.persistent = ignore_kwargs.get("persistent_quantizer", _PERSISTENT) + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=self.persistent) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) + self.register_buffer("_basis", _basis, persistent=self.persistent) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=self.persistent) + + def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat).float() + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: + """Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = rearrange(indices, "... -> ... 
1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes.to(self.dtype) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + is_img_or_video = z.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if is_img_or_video: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + indices = unpack_one(indices, ps, "b * c") + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True)) + else: + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 1 -> ...") + + return (indices, out.to(self.dtype), dummy_loss) + + +class VectorQuantizer(nn.Module): + """Improved version over VectorQuantizer. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + + Adapted from: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/ + taming/modules/vqvae/quantize.py + + [Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer] + https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/License.txt + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + beta: float = 0.25, + remap: str = None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + use_norm=False, + **ignore_kwargs, + ): + super().__init__() + self.n_e = num_embeddings + self.e_dim = embedding_dim + self.beta = beta + self.legacy = legacy + self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + else: + self.re_embed = num_embeddings + + self.sane_index_shape = sane_index_shape + self.dtype = ignore_kwargs.get("dtype", torch.float32) + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits is False, "Only for interface compatible with Gumbel" + assert return_logits is False, "Only for interface compatible with Gumbel" + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.e_dim) + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", + z_flattened, + rearrange(self.embedding.weight, "n d -> d n"), + ) + ) + + encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self.n_e, device=z.device) + encodings.scatter_(1, encoding_indices, 1) + z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape) + min_encodings = None + + z_q, z = self.norm(z_q), self.norm(z) + + # compute loss for embedding + commit_loss = torch.mean((z_q - z.detach()) ** 2, dim=[1, 2, 3], keepdim=True) + emb_loss = torch.mean((z_q.detach() - z) ** 2, dim=[1, 2, 3], keepdim=True) + if not self.legacy: + loss = self.beta * emb_loss + commit_loss + else: + loss = emb_loss + self.beta * commit_loss + + # preserve gradients + z_q = z + (z_q - z).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = encoding_indices.squeeze(1).reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(encoding_indices.squeeze(1)) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return ( + z_q, + loss, + ( + encoding_indices.squeeze(1), + min_encodings, + commit_loss.mean().detach(), + self.beta * emb_loss.mean().detach(), + perplexity.mean().detach(), + ), + ) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = 
z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class LFQuantizer(nn.Module): + """Lookup-Free Quantization + + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/lookup_free_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + *, + codebook_size: int, + codebook_dim: int, + embed_dim: Optional[int] = None, # if None, use codebook_dim + entropy_loss_weight=0.1, + commitment_loss_weight=0.25, + default_temp: float = 0.01, + entropy_loss: bool = False, + **ignore_kwargs, + ): + """Lookup-Free Quantization + + Args: + codebook_size (int): The number of entries in the codebook. + codebook_dim (int): The number of bits in each code. + embed_dim (Optional[int], optional): The dimension of the input embedding. Defaults to None. + entropy_loss_weight (float, optional): Whether to use entropy loss. Defaults to 0.1. + commitment_loss_weight (float, optional): Weight for commitment loss. Defaults to 0.25. + default_temp (float, optional): The temprature to use. Defaults to 0.01. + entropy_loss (bool, optional): Flag for entropy loss. Defaults to False. + """ + super().__init__() + self.entropy_loss = entropy_loss + self.codebook_dim = codebook_dim + self.default_temp = default_temp + self.entrop_loss_weight = entropy_loss_weight + self.commitment_loss_weight = commitment_loss_weight + embed_dim = embed_dim or codebook_dim + + has_projections = embed_dim != codebook_dim + self.project_in = nn.Linear(embed_dim, codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(codebook_dim, embed_dim) if has_projections else nn.Identity() + logging.info(f"LFQ: has_projections={has_projections}, dim_in={embed_dim}, codebook_dim={codebook_dim}") + + self.dtype = ignore_kwargs.get("dtype", torch.float32) + + if entropy_loss: + assert 2**codebook_dim == codebook_size, "codebook size must be 2 ** codebook_dim" + self.codebook_size = codebook_size + + self.register_buffer( + "mask", + 2 ** torch.arange(codebook_dim - 1, -1, -1), + persistent=_PERSISTENT, + ) + self.register_buffer("zero", torch.tensor(0.0), persistent=_PERSISTENT) + + all_codes = torch.arange(codebook_size) + bits = ((all_codes[..., None].int() & self.mask) != 0).float() + codebook = 2 * bits - 1.0 + + self.register_buffer("codebook", codebook, persistent=_PERSISTENT) # [codebook_size, codebook_dim] + + def forward(self, z: torch.Tensor, temp: float = None) -> torch.Tensor: + temp = temp or self.default_temp + + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + z = self.project_in(z) + + # split out number of codebooks + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + # quantization + original_input = z + + codebook_value = torch.ones_like(z) + z_q = torch.where(z > 0, codebook_value, -codebook_value) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # commit loss + commit_loss = ((original_input - z_q.detach()) ** 2).mean(dim=[1, 2, 3]) + + z_q = rearrange(z_q, "b n c d -> b n (c d)") + z_q = self.project_out(z_q) + + # reshape + z_q = unpack_one(z_q, ps, "b * d") + z_q = rearrange(z_q, "b ... 
d -> b d ...") + + loss = self.commitment_loss_weight * commit_loss + + # entropy loss (eq-5) + if self.entropy_loss: + # indices + indices = reduce((z > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") + indices = unpack_one(indices, ps, "b * c") + indices = rearrange(indices, "... 1 -> ...") + + distance = -2 * torch.einsum( + "... i d, j d -> ... i j", + original_input, + self.codebook.to(original_input.dtype), + ) + prob = (-distance / temp).softmax(dim=-1) + per_sample_entropy = entropy(prob).mean(dim=[1, 2]) + avg_prob = reduce(prob, "... c d -> c d", "mean") + codebook_entropy = entropy(avg_prob).mean() + entropy_aux_loss = per_sample_entropy - codebook_entropy + + loss += self.entrop_loss_weight * entropy_aux_loss + + return ( + z_q, + loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), + ( + indices, + self.commitment_loss_weight * commit_loss.mean().detach(), + self.entrop_loss_weight * entropy_aux_loss.mean().detach(), + self.entrop_loss_weight * per_sample_entropy.mean().detach(), + self.entrop_loss_weight * codebook_entropy.mean().detach(), + ), + ) + else: + return ( + z_q, + loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), + self.commitment_loss_weight * commit_loss.mean().detach(), + ) + + +class InvQuantizerJit(nn.Module): + """Use for decoder_jit to trace quantizer in discrete tokenizer""" + + def __init__(self, quantizer): + super().__init__() + self.quantizer = quantizer + + def forward(self, indices: torch.Tensor): + codes = self.quantizer.indices_to_codes(indices) + return codes.to(self.quantizer.dtype) diff --git a/cosmos_predict1/tokenizer/modules/utils.py b/cosmos_predict1/tokenizer/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..578bf2fa3f15e0dbe05054d30fda380f6d93e53f --- /dev/null +++ b/cosmos_predict1/tokenizer/modules/utils.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared utilities for the networks module.""" + +from typing import Any + +import torch +from einops import pack, rearrange, unpack + + +def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size = x.shape[0] + return rearrange(x, "b c t h w -> (b t) c h w"), batch_size + + +def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: + return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + + +def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size, height = x.shape[0], x.shape[-2] + return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height + + +def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: + return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) + + +def cast_tuple(t: Any, length: int = 1) -> Any: + return t if isinstance(t, tuple) else ((t,) * length) + + +def replication_pad(x): + return torch.cat([x[:, :, :1, ...], x], dim=2) + + +def divisible_by(num: int, den: int) -> bool: + return (num % den) == 0 + + +def is_odd(n: int) -> bool: + return not divisible_by(n, 2) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class CausalNormalize(torch.nn.Module): + def __init__(self, in_channels, num_groups=1): + super().__init__() + self.norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + ) + self.num_groups = num_groups + + def forward(self, x): + # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. + # All new models should use num_groups=1, otherwise causality is not guaranteed. + if self.num_groups == 1: + x, batch_size = time2batch(x) + return batch2time(self.norm(x), batch_size) + return self.norm(x) + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def round_ste(z: torch.Tensor) -> torch.Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) diff --git a/cosmos_predict1/tokenizer/networks/__init__.py b/cosmos_predict1/tokenizer/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b820aba6f474ec1ebe798b82b5f41362e0c43a4f --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/__init__.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum + +from cosmos_predict1.tokenizer.networks.configs import continuous_image_8x8_360p as continuous_image_8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import continuous_image_16x16_360p as continuous_image_16x16_360p_dict +from cosmos_predict1.tokenizer.networks.configs import continuous_video_4x8x8_360p as continuous_video_4x8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import continuous_video_8x8x8_720p as continuous_video_8x8x8_720p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_image_8x8_360p as discrete_image_8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_image_16x16_360p as discrete_image_16x16_360p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_video_4x8x8_360p as discrete_video_4x8x8_360p_dict +from cosmos_predict1.tokenizer.networks.configs import discrete_video_8x16x16_720p as discrete_video_8x16x16_720p_dict +from cosmos_predict1.tokenizer.networks.continuous_image import ContinuousImageTokenizer +from cosmos_predict1.tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer +from cosmos_predict1.tokenizer.networks.discrete_image import DiscreteImageTokenizer +from cosmos_predict1.tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer + + +class TokenizerConfigs(Enum): + """Continuous Image (CI) Tokenizer Configs""" + + # Cosmos-Tokenize1-CI8x8-360p + CI8x8_360p = continuous_image_8x8_360p_dict + + # Cosmos-Tokenize1-CI16x16-360p + CI16x16_360p = continuous_image_16x16_360p_dict + + """Discrete Image (DI) Tokenizer Configs""" + # Cosmos-Tokenize1-DI8x8-360p + DI8x8_360p = discrete_image_8x8_360p_dict + + # Cosmos-Tokenize1-DI16x16-360p + DI16x16_360p = discrete_image_16x16_360p_dict + + """Causal Continuous Video (CV) Tokenizer Configs""" + # Cosmos-Tokenize1-CV8x8x8-720p + CV8x8x8_720p = continuous_video_8x8x8_720p_dict + + # Cosmos-Tokenize1-CV4x8x8-360p + CV4x8x8_360p = continuous_video_4x8x8_360p_dict + + """Causal Discrete Video (DV) Tokenizer Configs""" + # Cosmos-Tokenize1-DV8x16x16-720p + DV8x16x16_720p = discrete_video_8x16x16_720p_dict + + # Cosmos-Tokenize1-DV4x8x8-360p + DV4x8x8_360p = discrete_video_4x8x8_360p_dict + + +class TokenizerModels(Enum): + CI = ContinuousImageTokenizer + DI = DiscreteImageTokenizer + CV = CausalContinuousVideoTokenizer + DV = CausalDiscreteVideoTokenizer diff --git a/cosmos_predict1/tokenizer/networks/configs.py b/cosmos_predict1/tokenizer/networks/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..540e04cefcfd08d1325c168830d5990f44eb760b --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/configs.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
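# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the diff): one plausible way to pair
# the two enums defined in networks/__init__.py above is to look up a config by
# name and pass the dict to the matching network class. Whether this constructs
# end to end depends on the encoder/decoder modules defined elsewhere in this
# diff, so treat it as a sketch rather than a documented entry point.
from cosmos_predict1.tokenizer.networks import TokenizerConfigs, TokenizerModels

config = dict(TokenizerConfigs.DV8x16x16_720p.value)  # plain dict of hyper-parameters
model = TokenizerModels.DV.value(**config)            # CausalDiscreteVideoTokenizer(...)
print(type(model).__name__, config["spatial_compression"], config["temporal_compression"])
# ---------------------------------------------------------------------------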
+ +"""The default image and video tokenizer configs.""" + +from cosmos_predict1.tokenizer.modules import ( + ContinuousFormulation, + Decoder3DType, + DecoderType, + DiscreteQuantizer, + Encoder3DType, + EncoderType, +) + +continuous_image = dict( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The output latent dimension (channels). + latent_channels=16, + # The encoder output channels just before sampling. + # Which is also the decoder's input channels. + z_channels=16, + # A factor over the z_channels, to get the total channels the encoder should output. + # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels. + z_factor=1, + name="CI", + # What formulation to use, either "AE" or "VAE". + # Chose VAE here, since the pre-trained ckpt were of a VAE formulation. + formulation=ContinuousFormulation.AE.name, + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) +continuous_image_8x8_360p = dict(continuous_image) +continuous_image_8x8_360p["patch_size"] = 2 +continuous_image_8x8_360p["spatial_compression"] = 8 + +continuous_image_16x16_360p = dict(continuous_image) +continuous_image_16x16_360p["patch_size"] = 2 +continuous_image_16x16_360p["spatial_compression"] = 16 + + +discrete_image = dict( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The encoder output channels just before sampling. + z_channels=256, + # A factor over the z_channels, to get the total channels the encoder should output. + # for discrete tokenization, often we directly use the vector, so z_factor=1. + z_factor=1, + # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ. + quantizer=DiscreteQuantizer.FSQ.name, + # The embedding dimension post-quantization, which is also the input channels of the decoder. + # Which is also the output + embedding_dim=6, + # The number of levels to use for fine-scalar quantization. + levels=[8, 8, 8, 5, 5, 5], + # The number of quantizers to use for residual fine-scalar quantization. 
+ num_quantizers=4, + name="DI", + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) +discrete_image_8x8_360p = dict(discrete_image) +discrete_image_8x8_360p["patch_size"] = 2 +discrete_image_8x8_360p["spatial_compression"] = 8 + +discrete_image_16x16_360p = dict(discrete_image) +discrete_image_16x16_360p["patch_size"] = 2 +discrete_image_16x16_360p["spatial_compression"] = 16 + +continuous_video = dict( + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + latent_channels=16, + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=8, + temporal_compression=8, + formulation=ContinuousFormulation.AE.name, + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="CV", +) + +continuous_video_8x8x8_720p = dict(continuous_video) +continuous_video_8x8x8_720p["temporal_compression"] = 8 +continuous_video_8x8x8_720p["spatial_compression"] = 8 + +continuous_video_4x8x8_360p = dict(continuous_video) +continuous_video_4x8x8_360p["temporal_compression"] = 4 +continuous_video_4x8x8_360p["spatial_compression"] = 8 +continuous_video_4x8x8_360p["patch_size"] = 2 + + +discrete_video = dict( + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + quantizer=DiscreteQuantizer.FSQ.name, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="DV", +) + +discrete_video_8x16x16_720p = dict(discrete_video) +discrete_video_8x16x16_720p["temporal_compression"] = 8 +discrete_video_8x16x16_720p["spatial_compression"] = 16 + +discrete_video_4x8x8_360p = dict(discrete_video) +discrete_video_4x8x8_360p["z_channels"] = 256 +discrete_video_4x8x8_360p["temporal_compression"] = 4 +discrete_video_4x8x8_360p["spatial_compression"] = 8 +discrete_video_4x8x8_360p["patch_size"] = 2 diff --git a/cosmos_predict1/tokenizer/networks/continuous_image.py b/cosmos_predict1/tokenizer/networks/continuous_image.py new file mode 100644 index 0000000000000000000000000000000000000000..c7e288d2f39bf69a0895f97b63a3ef3e04fbcdb6 --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/continuous_image.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
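# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the diff): the configs above are
# built by copying a base dict and overriding a few keys, so the variants are
# easy to inspect programmatically. Two quick checks on the dicts as defined:
import math

from cosmos_predict1.tokenizer.networks.configs import discrete_video, discrete_video_4x8x8_360p

# The 4x8x8 360p variant only changes patch_size, z_channels and the two compression ratios.
changed = {k: v for k, v in discrete_video_4x8x8_360p.items() if discrete_video.get(k) != v}
print(changed)  # patch_size=2, z_channels=256, spatial_compression=8, temporal_compression=4

# FSQ levels [8, 8, 8, 5, 5, 5] imply a 64,000-entry implicit codebook.
print(math.prod(discrete_video["levels"]))  # 64000
# ---------------------------------------------------------------------------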
+ +"""The continuous image tokenizer with VAE or AE formulation for 2D data.""" + +from collections import OrderedDict, namedtuple + +import torch +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import ContinuousFormulation, DecoderType, EncoderType + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) + + +class ContinuousImageTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "ContinuousImageTokenizer") + self.latent_channels = latent_channels + + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1) + + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + logging.info(f"{self.name} based on {formulation_name} formulation, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, latent_channels={self.latent_channels}.") + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("distribution", self.distribution), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + return self.distribution(moments) + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input) -> dict[str, torch.Tensor] | NetworkEval: + latent, posteriors = self.encode(input) + dec = self.decode(latent) + if self.training: + return dict(reconstructions=dec, posteriors=posteriors, latent=latent) + return NetworkEval(reconstructions=dec, posteriors=posteriors, latent=latent) diff --git a/cosmos_predict1/tokenizer/networks/continuous_video.py b/cosmos_predict1/tokenizer/networks/continuous_video.py new file mode 100644 index 0000000000000000000000000000000000000000..c054427d9c4a20f0f0d5e606ddb62e68cadfbeb3 --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/continuous_video.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The causal continuous video tokenizer with VAE or AE formulation for 3D data..""" +from collections import OrderedDict, namedtuple + +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import ContinuousFormulation, Decoder3DType, Encoder3DType +from cosmos_predict1.tokenizer.modules.layers3d import CausalConv3d + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) + + +class CausalContinuousVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") + self.latent_channels = latent_channels + + encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) + self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) + self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d( + z_factor * z_channels, + z_factor * latent_channels, + kernel_size=1, + padding=0, + ) + self.post_quant_conv = CausalConv3d(latent_channels, z_channels, kernel_size=1, padding=0) + + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + logging.info(f"{self.name} based on {formulation_name} formulation, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, latent_channels={self.latent_channels}.") + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("distribution", self.distribution), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + return self.distribution(moments) + + def decode(self, z): + z = self.post_quant_conv(z) + return self.decoder(z) + + def forward(self, input): + latent, posteriors = self.encode(input) + reconstructions = self.decode(latent) + if self.training: + return dict( + reconstructions=reconstructions, + posteriors=posteriors, + latent=latent, + ) + return NetworkEval( + reconstructions=reconstructions, + posteriors=posteriors, + latent=latent, + ) diff --git a/cosmos_predict1/tokenizer/networks/discrete_image.py b/cosmos_predict1/tokenizer/networks/discrete_image.py new file mode 100644 index 0000000000000000000000000000000000000000..02b160b43028912ea7b17cd38a0e98375aa839a7 --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/discrete_image.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The network definition for discrete image tokenization with VQ, LFQ, FSQ or ResidualFSQ.""" +from collections import OrderedDict, namedtuple + +import torch +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import DecoderType, DiscreteQuantizer, EncoderType +from cosmos_predict1.tokenizer.modules.quantizers import InvQuantizerJit + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class DiscreteImageTokenizer(nn.Module): + def __init__(self, z_channels: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "DiscreteImageTokenizer") + self.embedding_dim = embedding_dim + + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + self.encoder = EncoderType[encoder_name].value(z_channels=z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) + self.quant_conv = nn.Conv2d(z_channels, embedding_dim, 1) + self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1) + + quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) + if quantizer_name == DiscreteQuantizer.VQ.name: + assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." + kwargs.update(dict(embedding_dim=embedding_dim)) + elif quantizer_name == DiscreteQuantizer.LFQ.name: + assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." + assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.FSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.RESFSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}.name." + assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." 
+ self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) + logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(DiscreteImageTokenizer, self).to(*args, **kwargs) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("quantizer", self.quantizer), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("inv_quant", InvQuantizerJit(self.quantizer)), + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def decode_code(self, code_b): + quant_b = self.quantizer.indices_to_codes(code_b) + quant_b = self.post_quant_conv(quant_b) + return self.decoder(quant_b) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) + return NetworkEval( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) diff --git a/cosmos_predict1/tokenizer/networks/discrete_video.py b/cosmos_predict1/tokenizer/networks/discrete_video.py new file mode 100644 index 0000000000000000000000000000000000000000..db1aea39103f7e4bc92db10858890dd3b54b1e4b --- /dev/null +++ b/cosmos_predict1/tokenizer/networks/discrete_video.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The network definition for discrete video tokenizer with VQ, LFQ, FSQ or ResidualFSQ. 
""" +from collections import OrderedDict, namedtuple + +import torch +from loguru import logger as logging +from torch import nn + +from cosmos_predict1.tokenizer.modules import Decoder3DType, DiscreteQuantizer, Encoder3DType +from cosmos_predict1.tokenizer.modules.layers3d import CausalConv3d +from cosmos_predict1.tokenizer.modules.quantizers import InvQuantizerJit + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class CausalDiscreteVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") + self.embedding_dim = embedding_dim + + encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) + self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) + self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) + self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) + + quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) + if quantizer_name == DiscreteQuantizer.VQ.name: + assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." + kwargs.update(dict(embedding_dim=embedding_dim)) + elif quantizer_name == DiscreteQuantizer.LFQ.name: + assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." + assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.FSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.RESFSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." 
+ self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) + logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("quantizer", self.quantizer), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("inv_quant", InvQuantizerJit(self.quantizer)), + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def decode_code(self, code_b): + quant_b = self.quantizer.indices_to_codes(code_b) + quant_b = self.post_quant_conv(quant_b) + return self.decoder(quant_b) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) + return NetworkEval( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) diff --git a/cosmos_predict1/tokenizer/notebook/Image_Tokenization.ipynb b/cosmos_predict1/tokenizer/notebook/Image_Tokenization.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..75b3dcac46e5247fb646badad7e8a305b8bd50f2 --- /dev/null +++ b/cosmos_predict1/tokenizer/notebook/Image_Tokenization.ipynb @@ -0,0 +1,250 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "n3ryhkSfIEfl" + }, + "source": [ + "# Image Tokenization Using [NVIDIA Cosmos Tokenizer](https://github.com/NVIDIA-Cosmos/cosmos-predict1/blob/main/cosmos1/models/tokenizer) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nvidia-cosmos/cosmos-predict1/blob/main/cosmos_predict1/models/tokenizer/notebook/Image_Tokenization.ipynb)\n", + "\n", + "The Jupyter Notebook example utilizes the **Cosmos-Tokenizer** pretrained models, which include Continuous Image (CI) tokenizers that transform images into continuous latents and Discrete Image (DI) tokenizers that transform images into discrete tokens. Both CI and DI tokenizers are available with compression rates of 8x8 and 16x16. For instance, **CI16x16** effectively downsizes both height and width by a factor of 16.\n", + "\n", + "Within the notebook, the `ImageTokenizer` class from the `cosmos_tokenizer.image_lib` module is employed to manage the encoder and decoder components of this model. The encoder compresses the input image into a condensed latent representation or discrete integers, while the decoder reconstructs the image from this latent representation or discrete integers.\n", + "\n", + "This instance of the Cosmos Tokenizer demonstrates its autoencoding capability: compressing an image into a smaller latent space and subsequently reconstructing it to its original form. 
This showcases the efficiency of image tokenization for tasks involving significant spatial compression during image reconstruction, a highly desirable feature for generative modeling.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5BkjyLTPLM6e" + }, + "source": [ + "This tutorial follows a simple, step-by-step approach, making it easy to understand and adapt.\n", + "\n", + "## Step 1: Clone the Cosmos Tokenizer Repository" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TEV88M9YG973" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/NVIDIA-Cosmos/cosmos-predict1.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AxOMEJpFL9QL" + }, + "source": [ + "## Step 2: Install **Cosmos-Tokenizer**\n", + "Before proceeding, ensure you have the **Cosmos Tokenizer** installed. If you cloned the repository in Step 1, use the following command to install it in editable mode:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XuwUR6HrIxD8" + }, + "outputs": [], + "source": [ + "# Step 2: # Install Cosmos and its Python dependencies.\n", + "import os\n", + "if os.path.exists(\"cosmos-predict1\"):\n", + " os.chdir(\"cosmos-predict1\")\n", + " %pip install -r requirements.txt\n", + "else:\n", + " print('cosmos-predict1 is already installed.')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "id29RPiyMOtB" + }, + "source": [ + "## Step 3: Set Up Hugging Face API Token and Download Pretrained Models\n", + "\n", + "In this step, you'll configure the Hugging Face API token and download the pretrained model weights required for the **Cosmos Tokenizer**.\n", + "\n", + "1. **Ensure You Have a Hugging Face Account** \n", + " If you do not already have a Hugging Face account, follow these steps to create one and generate an API token:\n", + " - Go to the [Hugging Face website](https://huggingface.co/) and sign up for a free account.\n", + " - After logging in, navigate to your [Settings → Access Tokens](https://huggingface.co/settings/tokens).\n", + " - Click on \"New Token\" to generate an API token with the required permissions.\n", + "\n", + "2. **Set the Hugging Face Token** \n", + " Check if the Hugging Face token is already set in the environment variables. If not, you will be prompted to enter it manually. 
The token is essential to authenticate and access the Hugging Face models.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "joxcyOlnM7HQ" + }, + "outputs": [], + "source": [ + "# Check if the token is already set\n", + "if \"HUGGINGFACE_TOKEN\" not in os.environ:\n", + " os.environ[\"HUGGINGFACE_TOKEN\"] = input(\"Please enter your Hugging Face API token: \")\n", + "!git config --global credential.helper store" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Lq7MAQ9pGPH9" + }, + "outputs": [], + "source": [ + "from huggingface_hub import login, snapshot_download\n", + "import os\n", + "HUGGINGFACE_TOKEN = os.environ.get(\"HUGGINGFACE_TOKEN\")\n", + "login(token=HUGGINGFACE_TOKEN, add_to_git_credential=True)\n", + "model_names = [\n", + " \"Cosmos-0.1-Tokenizer-CI8x8\",\n", + " \"Cosmos-0.1-Tokenizer-CI16x16\",\n", + " \"Cosmos-0.1-Tokenizer-DI8x8\",\n", + " \"Cosmos-0.1-Tokenizer-DI16x16\",\n", + "]\n", + "for model_name in model_names:\n", + " hf_repo = \"nvidia/\" + model_name\n", + " local_dir = \"checkpoints/\" + model_name\n", + " os.makedirs(local_dir, exist_ok=True)\n", + " print(f\"downloading {model_name}...\")\n", + " snapshot_download(repo_id=hf_repo, local_dir=local_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ltZ-v2vzNv74" + }, + "source": [ + "## Step 4: Use Cosmos Tokenizer for Image Reconstruction\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 839 + }, + "id": "gZFPrGCBGwtC", + "outputId": "0df7efc4-7a40-4011-81a6-3c541ba1601f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input image read from:\t /content/Cosmos-Tokenizer/test_data/image.png\n", + "Reconstruction saved:\t /content/Cosmos-Tokenizer/test_data/image_CI8x8.png\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
Input Image
\n", + "
\n", + "
Reconstructed Image
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# @title In this step, load the required checkpoints, and perform image reconstruction. {\"run\":\"auto\"}\n", + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "\n", + "import importlib\n", + "from cosmos_predict1.tokenizer.inference.image_lib import ImageTokenizer\n", + "import mediapy as media\n", + "\n", + "\n", + "# 1) Specify the model name, and the paths to the encoder/decoder checkpoints.\n", + "model_name = 'Cosmos-0.1-Tokenizer-CI8x8' # @param [\"Cosmos-0.1-Tokenizer-CI16x16\", \"Cosmos-0.1-Tokenizer-CI8x8\", \"Cosmos-0.1-Tokenizer-DI8x8\", \"Cosmos-0.1-Tokenizer-DI16x16\"]\n", + "\n", + "encoder_ckpt = f\"checkpoints/{model_name}/encoder.jit\"\n", + "decoder_ckpt = f\"checkpoints/{model_name}/decoder.jit\"\n", + "\n", + "# 2) Load or provide the image filename you want to tokenize & reconstruct.\n", + "input_filepath = \"cosmos_predict1/tokenizer/test_data/image.png\"\n", + "\n", + "# 3) Read the image from disk (shape = H x W x 3 in BGR). Then convert to RGB.\n", + "input_image = media.read_image(input_filepath)[..., :3]\n", + "assert input_image.ndim == 3 and input_image.shape[2] == 3, \"Image must have shape H x W x 3\"\n", + "\n", + "# 4) Expand dimensions to B x H x W x C, since the ImageTokenizer expects a batch dimension\n", + "# in the input. (Batch size = 1 in this example.)\n", + "batched_input_image = np.expand_dims(input_image, axis=0)\n", + "\n", + "# 5) Create the ImageTokenizer instance with the encoder & decoder.\n", + "# - device=\"cuda\" uses the GPU\n", + "# - dtype=\"bfloat16\" expects Ampere or newer GPU (A100, RTX 30xx, etc.)\n", + "tokenizer = ImageTokenizer(\n", + " checkpoint_enc=encoder_ckpt,\n", + " checkpoint_dec=decoder_ckpt,\n", + " device=\"cuda\",\n", + " dtype=\"bfloat16\",\n", + ")\n", + "\n", + "# 6) Use the tokenizer to autoencode (encode & decode) the image.\n", + "# The output is a NumPy array with shape = B x H x W x C, range [0..255].\n", + "batched_output_image = tokenizer(batched_input_image)\n", + "\n", + "# 7) Extract the single image from the batch (index 0), convert to uint8.\n", + "output_image = batched_output_image[0]\n", + "\n", + "# 9) Save the reconstructed image to disk.\n", + "input_dir, input_filename = os.path.split(input_filepath)\n", + "filename, ext = os.path.splitext(input_filename)\n", + "output_filepath = f\"{input_dir}/{filename}_{model_name.split('-')[-1]}{ext}\"\n", + "media.write_image(output_filepath, output_image)\n", + "print(\"Input image read from:\\t\", f\"{os.getcwd()}/{input_filepath}\")\n", + "print(\"Reconstruction saved:\\t\", f\"{os.getcwd()}/{output_filepath}\")\n", + "\n", + "# 10) Visualization of the input image (left) and the reconstruction (right).\n", + "media.show_images([input_image, output_image], [\"Input Image\", \"Reconstructed Image\"])" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/cosmos_predict1/tokenizer/notebook/Video_Tokenization.ipynb b/cosmos_predict1/tokenizer/notebook/Video_Tokenization.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..31016a779c7069c352e9691a9965cb2f3f5051a5 --- /dev/null +++ b/cosmos_predict1/tokenizer/notebook/Video_Tokenization.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + 
{ + "cell_type": "markdown", + "metadata": { + "id": "n3ryhkSfIEfl" + }, + "source": [ + "# Video Tokenization Using [NVIDIA Cosmos Tokenizer](https://github.com/nvidia-cosmos/cosmos-predict1/blob/main/cosmos_predict1/models/tokenizer) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nvidia-cosmos/cosmos-predict1/blob/main/cosmos_predict1/models/tokenizer/notebook/Video_Tokenization.ipynb)\n", + "\n", + "The Jupyter Notebook example utilizes the **Cosmos-Tokenizer** pretrained models, which include Continuous Video (CV) tokenizers that transform videos into continuous spatio-temporal latents and Discrete Video (DI) tokenizers that transform videos into discrete tokens. Both CV and DV tokenizers are available with compression rates of (`TxHxW` format) 4x8x8 and 8x8x8, and 8x16x16. For instance, **CV4x8x8** effectively downsizes the number of frames by a factor of 4 and both height and width by a factor of 8.\n", + "\n", + "Within the notebook, the `VideoTokenizer` class from the `cosmos_tokenizer.video_lib` module is employed to manage the encoder and decoder components of this model. The encoder compresses the input video into a condensed latent representation or discrete integers, while the decoder reconstructs the video from this latent representation or discrete integers.\n", + "\n", + "This instance of the Cosmos Tokenizer demonstrates its autoencoding capability: compressing a video into a smaller latent space and subsequently reconstructing it to its original form. This showcases the efficiency of video tokenization for tasks involving significant spatial compression during video reconstruction, a highly desirable feature for generative modeling.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5BkjyLTPLM6e" + }, + "source": [ + "This tutorial follows a simple, step-by-step approach, making it easy to understand and adapt.\n", + "\n", + "## Step 1: Clone the Cosmos Tokenizer Repository" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TEV88M9YG973" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/NVIDIA-Cosmos/cosmos-predict1.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AxOMEJpFL9QL" + }, + "source": [ + "## Step 2: Install **Cosmos-Tokenizer**\n", + "Before proceeding, ensure you have the **Cosmos Tokenizer** installed. If you cloned the repository in Step 1, use the following command to install it in editable mode:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XuwUR6HrIxD8" + }, + "outputs": [], + "source": [ + "# Step 2: # Install Cosmos-Tokenizer and its Python dependencies.\n", + "import os\n", + "if os.path.exists(\"cosmos-predict1\"):\n", + " os.chdir(\"cosmos-predict1\")\n", + " !apt-get update\n", + " !apt-get install -y git-lfs\n", + " !git lfs pull\n", + " %pip install -r requirements.txt\n", + "else:\n", + " print('cosmos-predict1 is already installed.')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "id29RPiyMOtB" + }, + "source": [ + "## Step 3: Set Up Hugging Face API Token and Download Pretrained Models\n", + "\n", + "In this step, you'll configure the Hugging Face API token and download the pretrained model weights required for the **Cosmos Tokenizer**.\n", + "\n", + "1. 
**Ensure You Have a Hugging Face Account** \n", + " If you do not already have a Hugging Face account, follow these steps to create one and generate an API token:\n", + " - Go to the [Hugging Face website](https://huggingface.co/) and sign up for a free account.\n", + " - After logging in, navigate to your [Settings → Access Tokens](https://huggingface.co/settings/tokens).\n", + " - Click on \"New Token\" to generate an API token with the required permissions.\n", + "\n", + "2. **Set the Hugging Face Token** \n", + " Check if the Hugging Face token is already set in the environment variables. If not, you will be prompted to enter it manually. The token is essential to authenticate and access the Hugging Face models.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "joxcyOlnM7HQ" + }, + "outputs": [], + "source": [ + "# Check if the token is already set\n", + "if \"HUGGINGFACE_TOKEN\" not in os.environ:\n", + " os.environ[\"HUGGINGFACE_TOKEN\"] = input(\"Please enter your Hugging Face API token: \")\n", + "!git config --global credential.helper store" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Lq7MAQ9pGPH9" + }, + "outputs": [], + "source": [ + "from huggingface_hub import login, snapshot_download\n", + "import os\n", + "HUGGINGFACE_TOKEN = os.environ.get(\"HUGGINGFACE_TOKEN\")\n", + "login(token=HUGGINGFACE_TOKEN, add_to_git_credential=True)\n", + "model_names = [\n", + " \"Cosmos-0.1-Tokenizer-CV4x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-CV8x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-CV8x16x16\",\n", + " \"Cosmos-0.1-Tokenizer-DV4x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-DV8x8x8\",\n", + " \"Cosmos-0.1-Tokenizer-DV8x16x16\",\n", + " \"Cosmos-Tokenize1-CV8x8x8-720p\",\n", + " \"Cosmos-Tokenize1-DV8x16x16-720p\",\n", + "]\n", + "for model_name in model_names:\n", + " hf_repo = \"nvidia/\" + model_name\n", + " local_dir = \"checkpoints/\" + model_name\n", + " os.makedirs(local_dir, exist_ok=True)\n", + " print(f\"downloading {model_name}...\")\n", + " snapshot_download(repo_id=hf_repo, local_dir=local_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ltZ-v2vzNv74" + }, + "source": [ + "## Step 4: Use Cosmos Tokenizer for Video Reconstruction\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 594 + }, + "id": "gZFPrGCBGwtC", + "outputId": "ad18dc16-c1f2-410c-937b-787c677ec27e" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00, 6.45s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input video read from:\t /home/freda/Cosmos/cosmos1/models/tokenizer/test_data/video.mp4\n", + "Reconstruction saved:\t /home/freda/Cosmos/cosmos1/models/tokenizer/test_data/video_CV8x8x8.mp4\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
Input Video
\n", + "
\n", + "
Reconstructed Video
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# @title In this step, load the required checkpoints, and perform video reconstruction. {\"run\":\"auto\"}\n", + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "\n", + "import importlib\n", + "from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer\n", + "import mediapy as media\n", + "\n", + "\n", + "# 1) Specify the model name, and the paths to the encoder/decoder checkpoints.\n", + "model_name = 'Cosmos-Tokenize1-CV8x8x8-720p' # @param [\"Cosmos-0.1-Tokenizer-CV4x8x8\", \"Cosmos-0.1-Tokenizer-CV8x8x8\", \"Cosmos-0.1-Tokenizer-CV8x16x16\", \"Cosmos-0.1-Tokenizer-DV4x8x8\", \"Cosmos-0.1-Tokenizer-DV8x8x8\", \"Cosmos-0.1-Tokenizer-DV8x16x16\", \"Cosmos-Tokenize1-CV8x8x8-720p\", \"Cosmos-Tokenize1-DV8x16x16-720p\"]\n", + "temporal_window = 49 # @param {type:\"slider\", min:1, max:121, step:8}\n", + "\n", + "encoder_ckpt = f\"checkpoints/{model_name}/encoder.jit\"\n", + "decoder_ckpt = f\"checkpoints/{model_name}/decoder.jit\"\n", + "\n", + "# 2) Load or provide the video filename you want to tokenize & reconstruct.\n", + "input_filepath = \"cosmos_predict1/tokenizer/test_data/video.mp4\"\n", + "\n", + "# 3) Read the video from disk (shape = T x H x W x 3 in BGR).\n", + "input_video = media.read_video(input_filepath)[..., :3]\n", + "assert input_video.ndim == 4 and input_video.shape[-1] == 3, \"Frames must have shape T x H x W x 3\"\n", + "\n", + "# 4) Expand dimensions to B x Tx H x W x C, since the CausalVideoTokenizer expects a batch dimension\n", + "# in the input. (Batch size = 1 in this example.)\n", + "batched_input_video = np.expand_dims(input_video, axis=0)\n", + "\n", + "# 5) Create the CausalVideoTokenizer instance with the encoder & decoder.\n", + "# - device=\"cuda\" uses the GPU\n", + "# - dtype=\"bfloat16\" expects Ampere or newer GPU (A100, RTX 30xx, etc.)\n", + "tokenizer = CausalVideoTokenizer(\n", + " checkpoint_enc=encoder_ckpt,\n", + " checkpoint_dec=decoder_ckpt,\n", + " device=\"cuda\",\n", + " dtype=\"bfloat16\",\n", + ")\n", + "\n", + "# 6) Use the tokenizer to autoencode (encode & decode) the video.\n", + "# The output is a NumPy array with shape = B x T x H x W x C, range [0..255].\n", + "batched_output_video = tokenizer(batched_input_video,\n", + " temporal_window=temporal_window)\n", + "\n", + "# 7) Extract the single video from the batch (index 0).\n", + "output_video = batched_output_video[0]\n", + "\n", + "# 9) Save the reconstructed video to disk.\n", + "input_dir, input_filename = os.path.split(input_filepath)\n", + "filename, ext = os.path.splitext(input_filename)\n", + "output_filepath = f\"{input_dir}/{filename}_{model_name.split('-')[-1]}{ext}\"\n", + "media.write_video(output_filepath, output_video)\n", + "print(\"Input video read from:\\t\", f\"{os.getcwd()}/{input_filepath}\")\n", + "print(\"Reconstruction saved:\\t\", f\"{os.getcwd()}/{output_filepath}\")\n", + "\n", + "# 10) Visualization of the input video (left) and the reconstruction (right).\n", + "media.show_videos([input_video, output_video], [\"Input Video\", \"Reconstructed Video\"], height=480)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/cosmos_predict1/tokenizer/test_data/image.png 
b/cosmos_predict1/tokenizer/test_data/image.png new file mode 100644 index 0000000000000000000000000000000000000000..370b83e4fd1c42547cbe34190ff726994d2c34a6 --- /dev/null +++ b/cosmos_predict1/tokenizer/test_data/image.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27f2261a585eea38a0c9ec16f2ea81a2295b49c5ad6a3e39fc7cfdd1aa39f53b +size 1786433 diff --git a/cosmos_predict1/tokenizer/test_data/video.mp4 b/cosmos_predict1/tokenizer/test_data/video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..13dfd018f2ba6afb264c862cfc54a83b8d9e5f6b --- /dev/null +++ b/cosmos_predict1/tokenizer/test_data/video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b1112c71ee9f14b0d1d2b60a7b8d76bdd133d8fc788e5b41e104132f75bfb4f +size 3570241 diff --git a/cosmos_predict1/tokenizer/training/__init__.py b/cosmos_predict1/tokenizer/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/callbacks.py b/cosmos_predict1/tokenizer/training/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6e11f350a51fcff117a100a7774ca0b0b5961d --- /dev/null +++ b/cosmos_predict1/tokenizer/training/callbacks.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenizer callbacks extended from base callbacks.""" + +import math +from typing import Any, Optional + +import numpy as np +import torch +from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule + +from cosmos_predict1.utils import callback, distributed, log +from cosmos_predict1.utils.config import Config +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + +_UINT8_MAX_F = float(np.iinfo(np.uint8).max) +_VIDEO_CONSISTENCY_LOSS = "video_consistency" + + +def make_video_grid(video, nrow=None, padding=1): + r"""Make a grid of videos for visualization. + Args: + video (tensor): video of size B x C x T x H x W. + nrow (int): number of rows in the grid. + padding (int): size of paddings between videos. 
+ """ + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().detach().numpy() * _UINT8_MAX_F).astype("uint8") + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype="uint8") + + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r : start_r + h, start_c : start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + return video + + +def compute_weight_norm(model): + weight_norm = dict() + for layer_name, param in model.named_parameters(): + if torch.isnan(param).any(): + raise ValueError(f"[weight] {layer_name} NaN detected in gradients") + weight_norm[f"{layer_name}"] = torch.norm(param, p=2).item() + return weight_norm + + +def compute_grad_norm(model): + grad_norm = dict() + for layer_name, param in model.named_parameters(): + if param.grad is not None: + if torch.isnan(param.grad).any(): + raise ValueError(f"[grad] {layer_name} NaN detected in gradients") + grad_norm[f"{layer_name}"] = torch.norm(param.grad, p=2).item() + return grad_norm + + +class AdaptCkptStateDict(callback.Callback): + def __init__(self, config: Config, trainer: Trainer): + super().__init__(config, trainer) + + def on_save_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None: + """Adapt the state dict should the model be a compiled one.""" + if not isinstance(model.network, torch_OptimizedModule): + return + + def _uncompiled_key(k): + if k.startswith("network._orig_mod"): + return k.replace("network._orig_mod", "network") + elif k.startswith("ema.network-_orig_mod"): + return k.replace("ema.network-_orig_mod", "ema.network") + return k + + fixed_keys_state_dict = {} + + for k, v in state_dict["model"].items(): + fixed_keys_state_dict[_uncompiled_key(k)] = v + + state_dict["model"] = fixed_keys_state_dict + + def on_load_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None: + """Adapt the state dict should the model be a compiled one.""" + if not isinstance(model.network, torch_OptimizedModule): + return + + def _compiled_key(k): + if k.startswith("network."): + return k.replace("network", "network._orig_mod") + elif k.startswith("ema.network-"): + return k.replace("ema.network", "ema.network-_orig_mod") + return k + + fixed_keys_state_dict = {} + + for k, v in state_dict["model"].items(): + fixed_keys_state_dict[_compiled_key(k)] = v + + state_dict["model"] = fixed_keys_state_dict + + +class GradClipCallback(callback.GradClipCallback): + """The verbose tokenizer callback for gradient clipping.""" + + def __init__(self, grad_clip_norm: float, config: Config, trainer: Trainer, verbose: bool): + super().__init__(config, trainer, grad_clip_norm) + self.verbose = verbose + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + grad_scaler.unscale_(optimizer) + total_norm = torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm) + if torch.isnan(total_norm): + raise ValueError("[gradient clipping] NaN detected in gradient norms") + if torch.isfinite(total_norm) and total_norm > self.grad_clip_norm and self.verbose: + if model_ddp.module.network.training: + log.warning( + f"[net:{iteration:07d}] 
Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients." + ) + else: + log.warning( + f"[unknown:{iteration:07d}] Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients." + ) + + +class ExpandLossMask(callback.Callback): + def __init__(self, kernel_size: int, config: Config, trainer: Trainer): + super().__init__(config, trainer) + self.kernel_size = kernel_size + + def on_training_step_start(self, model: Model, data: dict[str, Any], iteration: int = 0) -> None: + """Expand loss_mask with max pooling (to cover some partial human regions)""" + + if "loss_mask" not in data.keys(): + return + + assert data["loss_mask"].ndim == 4 or data["loss_mask"].ndim == 5, "ndim of loss_mask must be 4 or 5" + + kernel_size = self.kernel_size + if data["loss_mask"].ndim == 4: + data["loss_mask"] = torch.nn.functional.max_pool2d( + data["loss_mask"], kernel_size, stride=1, padding=kernel_size // 2 + ) + else: + data["loss_mask"] = torch.nn.functional.max_pool3d( + data["loss_mask"], + (1, kernel_size, kernel_size), + stride=1, + padding=(0, kernel_size // 2, kernel_size // 2), + ) + + +class TorchCompile(callback.Callback): + """ + Callback to use torch.compile() on network or modules in losses(FlowLoss and PerceptualLoss) or both. + We compile them at later iteration as it prevents NCCL timeouts when times are very unstable during first iterations + """ + + _TORCH_DYNAMO_CACHE_SIZE = 128 + + def __init__( + self, + compile_after_iterations: int = 8, + compile_network: bool = False, + compile_loss: bool = False, + compile_loss_keys: list[str] = ["flow", "perceptual"], + ): + self.initial_iteration: Optional[int] = None + self.compile_after_iterations: int = compile_after_iterations + + self.compile_network: bool = compile_network + self.compile_loss: bool = compile_loss + + self.compile_loss_keys: list[str] = compile_loss_keys + + if self.compile_network or self.compile_loss: + torch._dynamo.config.cache_size_limit = TorchCompile._TORCH_DYNAMO_CACHE_SIZE + + # Hack to make ".training" work on "torch.compile()" module. 
+ # Value of ".training" is incorrectly set on torch.compile() module, when .eval() or .train() + # is invoked, but is correctly set on original module and this hack accesses that value + # I've created issue about this: https://github.com/pytorch/pytorch/issues/132986 + torch_OptimizedModule.training = property( + lambda self: self._orig_mod.training, lambda self, value: None, lambda self: None + ) + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + if not (self.compile_network or self.compile_loss): + return + + if self.initial_iteration is None: + log.info(f"Compilation will done on iteration {iteration + self.compile_after_iterations}") + self.initial_iteration = iteration + + if self.compile_network: + if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False: + log.warning( + '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), network will not be compiled' + ) + + if iteration - self.initial_iteration == self.compile_after_iterations: + if self.compile_network: + if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False: + log.warning( + '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), skipping network compilation' + ) + else: + log.info("Compiling network") + model.network = torch.compile(model.network, dynamic=False) + + if self.compile_loss: + for key in self.compile_loss_keys: + if key not in model.loss.loss_modules: + log.warning(f"Loss module for compilation with key: {key} not found") + else: + if ( + hasattr(model.loss.loss_modules[key], "checkpoint_activations") + and getattr(model.loss.loss_modules[key], "checkpoint_activations") is True + ): + log.warning( + f"torch.compile() doesn't work with activation checkpointing, skipping compilation for loss with key: {key}" + ) + else: + log.info(f"Compiling loss with key: {key}") + model.loss.loss_modules[key].torch_compile() diff --git a/cosmos_predict1/tokenizer/training/checkpointer.py b/cosmos_predict1/tokenizer/training/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb58c0d08149a869cc0e26321142fb8b3754ca2 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/checkpointer.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import threading + +import torch +from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule + +from cosmos_predict1.utils import callback, distributed, ema, log, misc +from cosmos_predict1.utils.checkpointer import Checkpointer +from cosmos_predict1.utils.config import CheckpointConfig, JobConfig +from cosmos_predict1.utils.model import Model + + +class TokenizerCheckpointer(Checkpointer): + """The tokenizer checkpointer, extends the shared checkpointer. + + Supports checkpoint saving/loading to local disk: + - network weights and training optimizer states. + - optionally, export a TorchScript version of the EMA model. + """ + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + super().__init__(config_checkpoint, config_job, callbacks) + self.callbacks = callbacks + self.config_jit = config_checkpoint.jit + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = -1, + **ignore_kwargs, + ) -> None: + """Saves network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer: The model optimizer. + scheduler: The optimization scheduler. + grad_scaler: The gradient scaler (for mixed precision training). + iteration: Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + model.eval() + checkpoint_file = f"iter_{iteration:09}.pt" + + if distributed.get_rank() == 0: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, self._get_ema_jit(model), checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local( + self, + state_dict: dict[str, torch.Tensor], + jit_models: dict[str, torch.ScriptModule], + checkpoint_file: str, + rank: int = 0, + ) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict: The state dict of the model/optimizer/scheduler. + ema_jit: A dict of TorchScript EMA model, representing the encoder, decoder and full model. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). 
+ """ + checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) + os.makedirs(self.checkpoint_dir_local, exist_ok=True) + try: + torch.save(state_dict, checkpoint_path) + for key, jit_model in jit_models.items(): + checkpoint_jit = checkpoint_path.replace(".pt", f"_{key}.jit") + torch.jit.save(jit_model, checkpoint_jit) + log.success(f"Saved checkpoint: {checkpoint_jit}") + if rank == 0: + self._write_latest_checkpoint_file(checkpoint_file) + log.success(f"Saved checkpoint (local): {checkpoint_path}") + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to save (local): {e}") + + def _get_ema_jit(self, model: Model) -> dict[str, torch.ScriptModule]: + """Returns a TorchScript version of ema models compiled by PyTorch JIT.""" + if not self.config_jit.enabled: + return dict() + input_shape = tuple(self.config_jit.input_shape) + example_input = torch.randn(input_shape) + dtype = getattr(torch, self.config_jit.dtype) + example_input = example_input.to(self.config_jit.device).to(dtype) + with ema.ema_scope(model, enabled=model.config.ema.enabled): + _model = model.network + if isinstance(_model, torch_OptimizedModule): + _model = _model._orig_mod + + # Make sure jit model output consistenly during consecutive calls + # Check here: https://github.com/pytorch/pytorch/issues/74534 + torch._C._jit_set_texpr_fuser_enabled(False) + + ema_jit = torch.jit.trace(_model, example_input, strict=self.config_jit.strict) + encoder_jit = torch.jit.trace(_model.encoder_jit(), example_input, strict=self.config_jit.strict) + decoder_example = encoder_jit(example_input) + if isinstance(decoder_example, tuple): + decoder_example = decoder_example[0] + else: + assert isinstance(decoder_example, torch.Tensor), "decoder_example should be a tensor or tuple" + decoder_jit = torch.jit.trace(_model.decoder_jit(), decoder_example, strict=self.config_jit.strict) + return {"ema": ema_jit, "enc": encoder_jit, "dec": decoder_jit} diff --git a/cosmos_predict1/tokenizer/training/configs/__init__.py b/cosmos_predict1/tokenizer/training/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/cosmos_predict1/tokenizer/training/configs/base/__init__.py b/cosmos_predict1/tokenizer/training/configs/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/configs/base/callback.py b/cosmos_predict1/tokenizer/training/configs/base/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..941ea69d78ef30651ed53b600cfe43c424bcdb3d --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/callback.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""callbacks config options: + +BASIC_CALLBACKS: always recommended to use +""" + +from cosmos_predict1.tokenizer.training.callbacks import ( + AdaptCkptStateDict, + ExpandLossMask, + GradClipCallback, + TorchCompile, +) +from cosmos_predict1.utils.callback import EMAModelCallback, LowPrecisionCallback, ProgressBarCallback +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L + +BASIC_CALLBACKS = dict( + low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER), + grad_clip=L(GradClipCallback)(grad_clip_norm=1, verbose=False, config=PLACEHOLDER, trainer=PLACEHOLDER), + ema=L(EMAModelCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER), + progress_bar=L(ProgressBarCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER), + expand_loss_mask=L(ExpandLossMask)(kernel_size=51, config=PLACEHOLDER, trainer=PLACEHOLDER), + adapt_ckpt_state_dict=L(AdaptCkptStateDict)(config=PLACEHOLDER, trainer=PLACEHOLDER), + torch_compile=L(TorchCompile)( + compile_after_iterations=8, + compile_network=False, + compile_loss=False, + compile_loss_keys=["flow", "perceptual"], + ), +) diff --git a/cosmos_predict1/tokenizer/training/configs/base/checkpoint.py b/cosmos_predict1/tokenizer/training/configs/base/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f869223833a0027c29ee6d9fae4bf460a85e0639 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/checkpoint.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
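# --- Illustrative sketch (not part of this diff): adding a project-specific callback next to
# BASIC_CALLBACKS using the same LazyCall/PLACEHOLDER pattern. `IterationLogger` and
# `CUSTOM_CALLBACKS` are hypothetical names invented for this example.
from cosmos_predict1.tokenizer.training.configs.base.callback import BASIC_CALLBACKS
from cosmos_predict1.utils import callback, log
from cosmos_predict1.utils.lazy_config import PLACEHOLDER
from cosmos_predict1.utils.lazy_config import LazyCall as L


class IterationLogger(callback.Callback):
    def on_training_step_start(self, model, data, iteration: int = 0) -> None:
        if iteration % 1000 == 0:
            log.info(f"starting iteration {iteration}")


CUSTOM_CALLBACKS = dict(
    **BASIC_CALLBACKS,
    iteration_logger=L(IterationLogger)(config=PLACEHOLDER, trainer=PLACEHOLDER),
)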
+ +"""checkpoints config options: + +CHECKPOINT_LOCAL: store at local file system + +""" +import attrs + +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config import make_freezable +from cosmos_predict1.utils.lazy_config import LazyDict + + +@make_freezable +@attrs.define(slots=False) +class ExperimentConfig: + # Enables enforcing experiment naming. + enabled: bool = True + # The project, e.g. edify_video4. + project: str = None + # The valid groups, e.g ["video"]. + groups: list[str] = None + # The approved name prefixes, e.g. ["DV1024", "DI256"]. + name_prefixes: list[str] = None + + +@make_freezable +@attrs.define(slots=False) +class TokenizerCheckpointConfig(config.CheckpointConfig): + # Experiment naming configs. + experiment: ExperimentConfig = attrs.field(factory=ExperimentConfig) + + +jit_config = config.JITConfig( + enabled=True, + input_shape=[1, 3, 1024, 1024], +) + +experiment_config = ExperimentConfig( + enabled=True, + project="cosmos_tokenizer", + groups=["debug", "video"], + name_prefixes=[ + f"{base}{size}" if base in ["CI", "DI"] else f"{base}{size}_Causal" + for base in ["CI", "DI", "CV", "DV"] + for size in [256, 320, 480, 512, 720, 1024, 1080] + ] + + [f"{base}{size}" for base in ["CV", "DV"] for size in [256, 320, 512, 720]] + + ["mock"], +) + +CHECKPOINT_LOCAL: LazyDict = attrs.asdict( + TokenizerCheckpointConfig( + save_iter=5000, + jit=jit_config, + experiment=experiment_config, + ) +) diff --git a/cosmos_predict1/tokenizer/training/configs/base/data.py b/cosmos_predict1/tokenizer/training/configs/base/data.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a698489074d8dfde2bf7dafbdb68f24203ede8 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/data.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""dataloader config options + +Available dataloader options: + image_loader_basic + video_loader_basic + joint_image_video_loader_basic +""" + +from torch.utils.data import DataLoader + +from cosmos_predict1.tokenizer.training.configs.base.mock_data import get_mock_video_dataloader +from cosmos_predict1.tokenizer.training.datasets.dataset_provider import dataset_entry +from cosmos_predict1.utils.lazy_config import LazyCall + +DATALOADER_OPTIONS = {} + + +def dataloader_register(key): + def decorator(func): + DATALOADER_OPTIONS[key] = func + return func + + return decorator + + +@dataloader_register("video_loader_basic") +def get_video_dataloader( + dataset_name, + is_train, + batch_size=1, + num_video_frames=25, + resolution="720", + crop_height=128, + num_workers=8, +): + if dataset_name.startswith("mock"): + return get_mock_video_dataloader( + batch_size=batch_size, + is_train=is_train, + num_video_frames=num_video_frames, + resolution=resolution, + crop_height=crop_height, + ) + return LazyCall(DataLoader)( + dataset=LazyCall(dataset_entry)( + dataset_name=dataset_name, + dataset_type="video", + is_train=is_train, + resolution=resolution, + crop_height=crop_height, + num_video_frames=num_video_frames, + ), + batch_size=batch_size, # 2 + num_workers=num_workers, # 8 + prefetch_factor=2, + shuffle=None, # do we need this? + sampler=None, + persistent_workers=False, + pin_memory=True, + ) diff --git a/cosmos_predict1/tokenizer/training/configs/base/loss.py b/cosmos_predict1/tokenizer/training/configs/base/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fcd754a5f0913fbf28e376773bde03bb4f9806 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/loss.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Loss config options + +Loss weights are scheduled using a piecewise linear LR schedule. The schedule is defined by a list of boundaries and values. + +`boundaries` is a list of integers representing the iteration at which the weight value changes. +`values` is a list of floats representing the weight value at each boundary. It should have one more value than `boundaries`. + +Example: + A loss's weight will be: + values[0] when step <= boundaries[0], + values[1] when step > boundaries[0] and step <= boundaries[1], + ..., and + values[-1] when step > boundaries[-1]. 
+""" +import attrs + +from cosmos_predict1.tokenizer.training.losses import ReduceMode +from cosmos_predict1.tokenizer.training.losses.continuous import ( + ColorLoss, + FlowLoss, + KLLoss, + PerceptualLoss, + TokenizerLoss, + VideoConsistencyLoss, +) +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class KLConfig: + # each step is greater than boundaries[-1], so weight=values[-1] + boundaries: list[int] = [0] + values: list[float] = [1e-6] + + +@attrs.define(slots=False) +class PerceptualConfig: + lpips_boundaries: list[int] = [500000] + lpips_values: list[float] = [0.1, 0.073] + # Layer weights for linearly combining the multi-layer vgg-based losses. + layer_weights: list[float] = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] + # Gram loss, whether to turn on, and what weights to use. + gram_enabled: bool = True + gram_boundaries: list[int] = [500000] + gram_values: list[float] = [0.0, 0.062] + # Corr loss, whether to turn on, and what weights to use. + corr_enabled: bool = False + corr_boundaries: list[int] = [0] + corr_values: list[float] = [0.0] + # In the example training memory usage dropped from 64.03 GiB to 60.54 GiB + # with checkpointing enabled for this loss for about 3.2% slowdown. + # With checkpointing this and PerceptualLoss memory usage dropped + # from 64.03 GiB to 52.94 GiB for about 18% slowdown + # more details in MR:949 + checkpoint_activations: bool = False + + +@attrs.define(slots=False) +class ColorConfig: + # Color (RGB) basic loss and its weight schedule. + norm: str = "L1" + boundaries: list[int] = [0] + values: list[float] = [1.0] + + +@attrs.define(slots=False) +class FlowConfig: + # Flow loss and its weight schedule. + boundaries: list[int] = [250000] + values: list[float] = [0.0, 0.01] + scale: int = 2 + # Flow loss depends on RAFT, as such it requires a specific dtype. + dtype: str = "bfloat16" + # In the example training memory usage dropped from 28GB to 23GB + # with checkpointing enabled for this loss + # With checkpointing this and PerceptualLoss memory usage dropped + # from 64.03 GiB to 52.94 GiB for about 18% slowdown + # more details in MR:949 + checkpoint_activations: bool = False + enabled: bool = False + + +@attrs.define(slots=False) +class VideoConsistencyConfig: + # Add consistency loss between overlapped video frames + boundaries: list[int] = [250000] + values: list[float] = [0.0, 0.01] + enabled: bool = False + num_frames: int = 9 + step: int = 1 + + +@attrs.define(slots=False) +class VideoLoss: + # The combined loss function, and its reduction mode. 
+ color: LazyDict = L(ColorLoss)(config=ColorConfig()) + kl: LazyDict = L(KLLoss)(config=KLConfig()) + perceptual: LazyDict = L(PerceptualLoss)(config=PerceptualConfig()) + flow: LazyDict = L(FlowLoss)(config=FlowConfig()) + video_consistency: LazyDict = L(VideoConsistencyLoss)(config=VideoConsistencyConfig()) + reduce: str = ReduceMode.MEAN.value # model.config.loss.config.reduce={'MEAN', 'SUM', 'SUM_PER_FRAME'} + + +VideoLossConfig: LazyDict = L(TokenizerLoss)(config=VideoLoss()) diff --git a/cosmos_predict1/tokenizer/training/configs/base/metric.py b/cosmos_predict1/tokenizer/training/configs/base/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..daecc1cb8beb9bbc2c18ec1acb978ee3b5fdcdab --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/metric.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metric configurations for the tokenizer model. + +Support for PSNR or SSIM, there are validation only metrics. +""" +import attrs + +from cosmos_predict1.tokenizer.training.metrics import CodeUsageMetric, PSNRMetric, SSIMMetric, TokenizerMetric +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class Metric: + # The combined loss function, and its reduction mode. + PSNR: LazyDict = L(PSNRMetric)() + SSIM: LazyDict = L(SSIMMetric)() + + +@attrs.define(slots=False) +class DiscreteTokenizerMetric: + # with code usage (perplexity PPL), for discrete tokenizers only + PSNR: LazyDict = L(PSNRMetric)() + SSIM: LazyDict = L(SSIMMetric)() + CodeUsage: LazyDict = L(CodeUsageMetric)(codebook_size=64000) + + +MetricConfig: LazyDict = L(TokenizerMetric)(config=Metric()) + +DiscreteTokenizerMetricConfig: LazyDict = L(TokenizerMetric)(config=DiscreteTokenizerMetric()) diff --git a/cosmos_predict1/tokenizer/training/configs/base/mock_data.py b/cosmos_predict1/tokenizer/training/configs/base/mock_data.py new file mode 100644 index 0000000000000000000000000000000000000000..47e5a8efa694cdb00873b2c38d0c224be9109db3 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/mock_data.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from torch.utils.data import DataLoader + +from cosmos_predict1.tokenizer.training.datasets.mock_dataset import CombinedDictDataset, LambdaDataset +from cosmos_predict1.tokenizer.training.datasets.utils import VIDEO_KEY, VIDEO_VAL_CROP_SIZE_INFO, get_crop_size_info +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +_IMAGE_ASPECT_RATIO = "1,1" +_VIDEO_ASPECT_RATIO = "16,9" + + +def get_video_dataset( + is_train: bool, + resolution: str, + crop_height: int, + num_video_frames: int, +): + if is_train: + crop_sizes = get_crop_size_info(crop_height) + log.info( + f"[video] training num_frames={num_video_frames}, crop_height={crop_height} and crop_sizes: {crop_sizes}." + ) + else: + if crop_height is None: + crop_sizes = VIDEO_VAL_CROP_SIZE_INFO[resolution] + else: + crop_sizes = get_crop_size_info(crop_height) + log.info(f"[video] validation num_frames={num_video_frames}, crop_sizes: {crop_sizes}") + + h = crop_sizes[_VIDEO_ASPECT_RATIO][1] + w = crop_sizes[_VIDEO_ASPECT_RATIO][0] + + def video_fn(): + return 2 * torch.rand(3, num_video_frames, h, w) - 1 + + return CombinedDictDataset( + **{ + VIDEO_KEY: LambdaDataset(video_fn), + } + ) + + +def get_mock_video_dataloader( + batch_size: int, is_train: bool = True, num_video_frames: int = 9, resolution: str = "720", crop_height: int = 128 +) -> LazyDict: + """A function to get mock video dataloader. + + Args: + batch_size: The batch size. + num_video_frames: The number of video frames. + resolution: The resolution. Defaults to "1024". + + Returns: + LazyDict: A LazyDict object specifying the video dataloader. + """ + if resolution not in VIDEO_VAL_CROP_SIZE_INFO: + resolution = "720" + return L(DataLoader)( + dataset=L(get_video_dataset)( + is_train=is_train, + resolution=resolution, + crop_height=crop_height, + num_video_frames=num_video_frames, + ), + batch_size=batch_size, + shuffle=False, + num_workers=8, + ) diff --git a/cosmos_predict1/tokenizer/training/configs/base/model.py b/cosmos_predict1/tokenizer/training/configs/base/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2f859195d4b81253bfba583e1429ed6749b33e --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/model.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
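# --- Illustrative sketch (not part of this diff): pulling one mock sample from the dataset
# factory defined above. Indexing assumes CombinedDictDataset follows the usual torch Dataset
# protocol; the argument values here are just an example.
from cosmos_predict1.tokenizer.training.configs.base.mock_data import get_video_dataset
from cosmos_predict1.tokenizer.training.datasets.utils import VIDEO_KEY

mock_ds = get_video_dataset(is_train=True, resolution="720", crop_height=128, num_video_frames=9)
sample = mock_ds[0]
video = sample[VIDEO_KEY]  # random tensor of shape (3, 9, H, W) with values in [-1, 1]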
+ +import attrs + +from cosmos_predict1.tokenizer.training.model import TokenizerModel +from cosmos_predict1.utils.config import EMAConfig +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + + +@attrs.define(slots=False) +class ModelConfig: + network: LazyDict = None + loss: LazyDict = None + metric: LazyDict = None + ema: EMAConfig = EMAConfig(enabled=True, beta=0.9999) + precision: str = "bfloat16" + torch_compile: bool = False + disc: LazyDict = None + disc_optimizer: LazyDict = None + disc_scheduler: LazyDict = None + + +DefaultModelConfig: LazyDict = L(TokenizerModel)(config=ModelConfig()) diff --git a/cosmos_predict1/tokenizer/training/configs/base/net.py b/cosmos_predict1/tokenizer/training/configs/base/net.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0ae02ff7b34a67ce0c90324e80f6d6c6c5d153 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/net.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Net config options for cosmos/tokenizer + +ContinuousImageTokenizerConfig +DiscreteImageTokenizerConfig +CausalContinuousVideoTokenizerConfig + +""" + +from cosmos_predict1.tokenizer.modules import ( + ContinuousFormulation, + Decoder3DType, + DecoderType, + DiscreteQuantizer, + Encoder3DType, + EncoderType, +) +from cosmos_predict1.tokenizer.networks.continuous_image import ContinuousImageTokenizer +from cosmos_predict1.tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer +from cosmos_predict1.tokenizer.networks.discrete_image import DiscreteImageTokenizer +from cosmos_predict1.tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict + +ContinuousImageTokenizerConfig: LazyDict = L(ContinuousImageTokenizer)( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio, default 8. + spatial_compression=8, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The output latent dimension (channels). + latent_channels=16, + # The encoder output channels just before sampling. + # Which is also the decoder's input channels. + z_channels=16, + # A factor over the z_channels, to get the total channels the encoder should output. + # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels. + # Since we are using AE formulation, we only need the mean, so z_factor=1. + z_factor=1, + name="ContinuousImageTokenizer", + # What formulation to use, either "AE" or "VAE". 
+ # Chose AE here, since this has been proven to be effective. + formulation=ContinuousFormulation.AE.name, + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) + +DiscreteImageTokenizerConfig: LazyDict = L(DiscreteImageTokenizer)( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The encoder output channels just before sampling. + z_channels=256, + # A factor over the z_channels, to get the total channels the encoder should output. + # for discrete tokenization, often we directly use the vector, so z_factor=1. + z_factor=1, + # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ. Default FSQ. + quantizer=DiscreteQuantizer.FSQ.name, + # The embedding dimension post-quantization, which is also the input channels of the decoder. + # Which is also the output + embedding_dim=6, + # The number of levels to use for fine-scalar quantization. + levels=[8, 8, 8, 5, 5, 5], + persistent_quantizer=False, + # The number of quantizers to use for residual fine-scalar quantization. + num_quantizers=4, + name="DiscreteImageTokenizer", + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) + +CausalContinuousFactorizedVideoTokenizerConfig: LazyDict = L(CausalContinuousVideoTokenizer)( + # The new causal continuous tokenizer, that is at least 2x more efficient in memory and runtime. + # - It relies on fully 3D discrete wavelet transform + # - Uses a layer norm instead of a group norm + # - Factorizes full convolutions into spatial and temporal convolutions + # - Factorizes full attention into spatial and temporal attention + # - Adopts an AE formulation + # - Strictly causal, with flexible temporal length at inference. + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + latent_channels=16, + z_channels=16, + z_factor=1, + num_groups=1, + # Most of the CV and DV tokenizers trained before September 1, 2024, + # used temporal upsampling that was not perfectly mirrored with the + # # encoder's temporal downsampling. Moving forward, new CV/DV tokenizers + # will use legacy_mode=False, meaning they will adopt mirrored upsampling. + legacy_mode=False, + spatial_compression=8, + temporal_compression=8, + formulation=ContinuousFormulation.AE.name, + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="CausalContinuousFactorizedVideoTokenizer", +) + +CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)( + # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime. + # - It relies on fully 3D discrete wavelet transform + # - Uses a layer norm instead of a group norm + # - Factorizes full convolutions into spatial and temporal convolutions + # - Factorizes full attention into spatial and temporal attention + # - Strictly causal, with flexible temporal length at inference. 
+ attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The encoder output channels just before quantization is changed to 256 + # from 16 (old versions). It aligns with the DI that uses 256 channels, + # making initialization from image tokenizers easier. + z_channels=256, + z_factor=1, + num_groups=1, + # Most of the CV and DV tokenizers trained before September 1, 2024, + # used temporal upsampling that was not perfectly mirrored with the + # # encoder's temporal downsampling. Moving forward, new CV/DV tokenizers + # will use legacy_mode=False, meaning they will adopt mirrored upsampling. + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + quantizer=DiscreteQuantizer.FSQ.name, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + persistent_quantizer=False, + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="CausalDiscreteFactorizedVideoTokenizer", +) diff --git a/cosmos_predict1/tokenizer/training/configs/base/optim.py b/cosmos_predict1/tokenizer/training/configs/base/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb3b7eef0c7d9c973806fb3cfc935a409c27fd9 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/base/optim.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
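# --- Illustrative note (not part of this diff): the FSQ levels [8, 8, 8, 5, 5, 5] used by the
# discrete tokenizer configs above define a codebook of 8*8*8*5*5*5 = 64000 entries, which
# matches codebook_size=64000 used by CodeUsageMetric in the metric config.
import math

assert math.prod([8, 8, 8, 5, 5, 5]) == 64_000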
+ +"""optimizer config options: + +fused_adam - FusedAdamConfig +adamw - AdamWConfig +""" + +import torch + +from cosmos_predict1.utils import fused_adam +from cosmos_predict1.utils.lazy_config import PLACEHOLDER +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.scheduler import WarmupCosineLR, WarmupLambdaLR + +FusedAdamConfig: LazyDict = L(fused_adam.FusedAdam)( + capturable=True, + master_weights=True, + adam_w_mode=True, + params=PLACEHOLDER, + lr=1e-4, + betas=(0.5, 0.999), + eps=1e-8, + weight_decay=0.01, +) + +AdamWConfig: LazyDict = L(torch.optim.AdamW)( + params=PLACEHOLDER, + lr=1e-4, + betas=(0.5, 0.999), + eps=1e-8, + weight_decay=0.01, +) + +WarmupLRConfig: LazyDict = L(WarmupLambdaLR)(optimizer=PLACEHOLDER, warmup=5000) + +FusedAdamDiscConfig: LazyDict = L(fused_adam.FusedAdam)( + capturable=True, + master_weights=True, + adam_w_mode=True, + params=PLACEHOLDER, + lr=4e-4, + betas=(0.5, 0.999), + eps=1e-8, + weight_decay=0.01, +) + +WarmupLRDiscConfig: LazyDict = L(WarmupLambdaLR)(optimizer=PLACEHOLDER, warmup=5000) + +WarmupCosineLRConfig: LazyDict = L(WarmupCosineLR)( + optimizer=PLACEHOLDER, warmup_iters=5000, lr_decay_iters=1000000, min_lr=1e-8 +) diff --git a/cosmos_predict1/tokenizer/training/configs/config.py b/cosmos_predict1/tokenizer/training/configs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..afb94f8813da7d78f48a79a98ef362eed681c5ef --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/config.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Default config for cosmos/tokenizer project.""" + +from typing import Any, List + +import attrs + +from cosmos_predict1.tokenizer.training.configs.base.model import DefaultModelConfig +from cosmos_predict1.tokenizer.training.configs.registry import register_configs +from cosmos_predict1.tokenizer.training.trainer import TokenizerTrainer +from cosmos_predict1.utils import config +from cosmos_predict1.utils.config_helper import import_all_modules_from_package + + +@attrs.define(slots=False) +class Config(config.Config): + defaults: List[Any] = attrs.field( + factory=lambda: [ + "_self_", + {"data_train": "mock_video720"}, + {"data_val": "mock_video720"}, + {"optimizer": "fused_adam"}, + {"scheduler": "warmup"}, + {"network": "continuous_factorized_video"}, + {"loss": "video"}, + {"metric": "reconstruction"}, + {"checkpoint": "local"}, + {"callbacks": "basic"}, + {"experiment": None}, + ] + ) + + +def make_config(): + c = Config( + model=DefaultModelConfig, + optimizer=None, + scheduler=None, + dataloader_train=None, + dataloader_val=None, + checkpoint=None, + ) + c.job.project = "posttraining" + c.job.group = "debug" + c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}" + + c.trainer.type = TokenizerTrainer + c.trainer.run_validation = True + + c.trainer.seed = 1234 + c.trainer.max_iter = 10_000_000 + c.trainer.validation_iter = 5000 + c.trainer.max_val_iter = 1 + c.trainer.logging_iter = 100 + + c.trainer.callbacks = None + c.trainer.ddp.static_graph = True + c.trainer.ddp.find_unused_parameters = False + register_configs() + + # experiment config are defined in the experiment folder + # call import_all_modules_from_package to register them + import_all_modules_from_package("cosmos_predict1.tokenizer.training.configs.experiments") + + return c diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/__init__.py b/cosmos_predict1/tokenizer/training/configs/experiments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/basic.py b/cosmos_predict1/tokenizer/training/configs/experiments/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..9201f00af0de69e570cbcb2588f468125e2b44ba --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/experiments/basic.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Config settings for cosmos/tokenizer (basic image and video setting)""" + +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.utils.lazy_config import LazyDict + +CAUSAL_VIDEO_BASIC: LazyDict = LazyDict( + dict( + defaults=[ + {"override /network": "continuous_factorized_video"}, + {"override /data_train": "mock_video720"}, + {"override /data_val": "mock_video720"}, + {"override /loss": "video"}, + {"override /optimizer": "fused_adam"}, + {"override /callbacks": ["basic"]}, + "_self_", + ], + model=dict( + config=dict( + loss=dict( + config=dict( + perceptual=dict( + config=dict( + lpips_boundaries=[0], + lpips_values=[0.1], + gram_enabled=False, + gram_boundaries=[0], + ) + ), + video_consistency=dict( + config=dict( + enabled=False, + boundaries=[0], + values=[1.0], + num_frames=32, + step=8, + ) + ), + flow=dict( + config=dict( + enabled=False, + boundaries=[1_000_000], + values=[0.0, 0.01], + scale=2, + dtype="bfloat16", + checkpoint_activations=False, + ) + ), + ) + ) + ) + ), + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=720, + num_video_frames=49, + ), + batch_size=1, + ), + job=dict( + project="posttraining", + group="tokenizer", + name="basic_${now:%Y-%m-%d}_${now:%H-%M-%S}", + ), + checkpoint=dict(load_path=None, jit=dict(input_shape=[1, 3, 17, 512, 512])), + ) +) + +cs = ConfigStore.instance() +cs.store(group="experiment", package="_global_", name="video_basic", node=CAUSAL_VIDEO_BASIC) diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/cosmos_tokenize1.py b/cosmos_predict1/tokenizer/training/configs/experiments/cosmos_tokenize1.py new file mode 100644 index 0000000000000000000000000000000000000000..42c1d4f27c2cb7806d0482287f0909d49342ca9b --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/experiments/cosmos_tokenize1.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.tokenizer.training.configs.experiments.utils import create_debug_job_with_mock_data +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyDict + +# Post-training config for Cosmos-Tokenize1-CV8x8x8-720p-HDVILA +Cosmos_Tokenize1_CV8x8x8_720p_HDVILA: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/video_basic", + {"override /network": "continuous_factorized_video"}, + {"override /data_train": "hdvila_video720"}, + {"override /data_val": "hdvila_video720"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=121, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=121, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + channels_mult=[2, 4, 4], + patch_size=4, + legacy_mode=False, + temporal_compression=8, + spatial_compression=8, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-CV8x8x8-720p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +# Post-training config for Cosmos-Tokenize1-DV8x16x16-720p-HDVILA +Cosmos_Tokenize1_DV8x16x16_720p_HDVILA: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/video_basic", + {"override /network": "discrete_factorized_video"}, + {"override /data_train": "hdvila_video720"}, + {"override /data_val": "hdvila_video720"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + persistent_quantizer=False, + z_channels=16, + channels_mult=[2, 4, 4], + patch_size=4, + legacy_mode=False, + temporal_compression=8, + spatial_compression=16, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-DV8x16x16-720p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +# Post-training config for Cosmos-Tokenize1-CV4x8x8-360p-HDVILA +Cosmos_Tokenize1_CV4x8x8_360p_HDVILA: LazyDict = LazyDict( + dict( + defaults=[ + "/experiment/video_basic", + {"override /network": "continuous_factorized_video"}, + {"override /data_train": "hdvila_video360"}, + {"override /data_val": "hdvila_video360"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + channels_mult=[2, 4, 4], + patch_size=2, + legacy_mode=False, + temporal_compression=4, + spatial_compression=8, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-CV4x8x8-360p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-CV4x8x8-360p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +# Post-training config for Cosmos-Tokenize1-DV4x8x8-360p-HDVILA +Cosmos_Tokenize1_DV4x8x8_360p_HDVILA: LazyDict = LazyDict( + dict( + 
defaults=[ + "/experiment/video_basic", + {"override /network": "discrete_factorized_video"}, + {"override /data_train": "hdvila_video360"}, + {"override /data_val": "hdvila_video360"}, + "_self_", + ], + dataloader_train=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + dataloader_val=dict( + dataset=dict( + crop_height=256, + num_video_frames=49, + ), + batch_size=1, + ), + model=dict( + config=dict( + network=dict( + persistent_quantizer=False, + z_channels=256, + channels_mult=[2, 4, 4], + patch_size=2, + legacy_mode=False, + temporal_compression=4, + spatial_compression=8, + ) + ) + ), + job=dict( + project="posttraining", + group="tokenizer", + name="Cosmos-Tokenize1-DV4x8x8-360p-HDVILA", + ), + checkpoint=dict( + load_path="checkpoints/Cosmos-Tokenize1-DV4x8x8-360p/model.pt", + strict_resume=True, + load_training_state=True, + jit=dict(input_shape=[1, 3, 17, 512, 512]), + ), + ) +) + +cs = ConfigStore.instance() + +for _item in [ + Cosmos_Tokenize1_CV8x8x8_720p_HDVILA, + Cosmos_Tokenize1_DV8x16x16_720p_HDVILA, + Cosmos_Tokenize1_CV4x8x8_360p_HDVILA, + Cosmos_Tokenize1_DV4x8x8_360p_HDVILA, +]: + experiment_name = [name for name, value in globals().items() if value is _item][0] + + log.info(f"Registering experiment: {experiment_name}") + cs.store( + group="experiment", + package="_global_", + name=experiment_name, + node=_item, + ) + + mock_experiment = f"mock_{experiment_name}" + log.info(f"Registering mock experiment: {mock_experiment}") + _debug_item = create_debug_job_with_mock_data(_item["job"]["name"]) + cs.store( + group="experiment", + package="_global_", + name=mock_experiment, + node=_debug_item, + ) diff --git a/cosmos_predict1/tokenizer/training/configs/experiments/utils.py b/cosmos_predict1/tokenizer/training/configs/experiments/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d68db0312c7abe891b99d55683a776ba6a1e43e --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/experiments/utils.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
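The experiment names encode the temporal x spatial x spatial compression factors set in their `network` blocks (CV8x8x8, DV8x16x16, and so on). The quick arithmetic below sketches the latent grids those settings imply for the training crops and for the JIT export shape `[1, 3, 17, 512, 512]`; it assumes the usual causal convention of `1 + (T - 1) / temporal_compression` latent frames, which this diff does not state explicitly.

```python
def latent_shape(frames: int, height: int, width: int, t: int, s: int) -> tuple:
    """(T, H, W) of the latent grid for a causal tokenizer with TxSxS compression.

    Assumes 1 + (T - 1) / t latent frames (causal convention; an assumption here).
    """
    return (1 + (frames - 1) // t, height // s, width // s)

print(latent_shape(121, 256, 256, t=8, s=8))   # CV8x8x8 train crops     -> (16, 32, 32)
print(latent_shape(49, 256, 256, t=8, s=16))   # DV8x16x16 train crops   -> (7, 16, 16)
print(latent_shape(49, 256, 256, t=4, s=8))    # CV4x8x8 / DV4x8x8 crops -> (13, 32, 32)
print(latent_shape(17, 512, 512, t=8, s=8))    # JIT export input shape  -> (3, 64, 64)
```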
+ +"""registry for commandline override options for config.""" +from cosmos_predict1.utils.lazy_config import LazyDict + + +def create_debug_job_with_mock_data(full_experiment_name): + job_dict = dict( + defaults=[ + f"/experiment/{full_experiment_name.replace('-', '_')}", + {"override /data_train": "mock_video360"}, + {"override /data_val": "mock_video360"}, + "_self_", + ], + job=dict(group="debug", name=f"mock_{full_experiment_name}" + "_${now:%Y-%m-%d}_${now:%H-%M-%S}"), + trainer=dict( + max_iter=2, + logging_iter=1, + max_val_iter=1, + validation_iter=2, + ), + checkpoint=dict( + strict_resume=False, + load_training_state=False, + save_iter=2, + ), + ) + return LazyDict(job_dict) diff --git a/cosmos_predict1/tokenizer/training/configs/registry.py b/cosmos_predict1/tokenizer/training/configs/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..cda4109b0b6edbc0cbfa29f9af676f359dc90f1f --- /dev/null +++ b/cosmos_predict1/tokenizer/training/configs/registry.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""registry for commandline override options for config.""" +from hydra.core.config_store import ConfigStore + +from cosmos_predict1.tokenizer.training.configs.base.callback import BASIC_CALLBACKS +from cosmos_predict1.tokenizer.training.configs.base.checkpoint import CHECKPOINT_LOCAL +from cosmos_predict1.tokenizer.training.configs.base.data import DATALOADER_OPTIONS +from cosmos_predict1.tokenizer.training.configs.base.loss import VideoLossConfig +from cosmos_predict1.tokenizer.training.configs.base.metric import DiscreteTokenizerMetricConfig, MetricConfig +from cosmos_predict1.tokenizer.training.configs.base.net import ( + CausalContinuousFactorizedVideoTokenizerConfig, + CausalDiscreteFactorizedVideoTokenizerConfig, + ContinuousImageTokenizerConfig, + DiscreteImageTokenizerConfig, +) +from cosmos_predict1.tokenizer.training.configs.base.optim import ( + AdamWConfig, + FusedAdamConfig, + WarmupCosineLRConfig, + WarmupLRConfig, +) + + +def register_training_data(cs): + for data_source in ["mock", "hdvila"]: + for resolution in ["1080", "720", "480", "360", "256"]: + cs.store( + group="data_train", + package="dataloader_train", + name=f"{data_source}_video{resolution}", # `davis_video720` + node=DATALOADER_OPTIONS["video_loader_basic"]( + dataset_name=f"{data_source}_video", + is_train=True, + resolution=resolution, + ), + ) + + +def register_val_data(cs): + for data_source in ["mock", "hdvila"]: + for resolution in ["1080", "720", "480", "360", "256"]: + cs.store( + group="data_val", + package="dataloader_val", + name=f"{data_source}_video{resolution}", # `davis_video720` + node=DATALOADER_OPTIONS["video_loader_basic"]( + dataset_name=f"{data_source}_video", + is_train=False, + resolution=resolution, + ), + ) + + +def register_net(cs): + cs.store( + group="network", package="model.config.network", 
name="continuous_image", node=ContinuousImageTokenizerConfig + ) + cs.store(group="network", package="model.config.network", name="discrete_image", node=DiscreteImageTokenizerConfig) + + cs.store( + group="network", + package="model.config.network", + name="continuous_factorized_video", + node=CausalContinuousFactorizedVideoTokenizerConfig, + ) + cs.store( + group="network", + package="model.config.network", + name="discrete_factorized_video", + node=CausalDiscreteFactorizedVideoTokenizerConfig, + ) + + +def register_optim(cs): + cs.store(group="optimizer", package="optimizer", name="fused_adam", node=FusedAdamConfig) + cs.store(group="optimizer", package="optimizer", name="adamw", node=AdamWConfig) + + +def register_scheduler(cs): + cs.store(group="scheduler", package="scheduler", name="warmup", node=WarmupLRConfig) + cs.store( + group="scheduler", + package="scheduler", + name="warmup_cosine", + node=WarmupCosineLRConfig, + ) + + +def register_loss(cs): + cs.store(group="loss", package="model.config.loss", name="video", node=VideoLossConfig) + + +def register_metric(cs): + cs.store(group="metric", package="model.config.metric", name="reconstruction", node=MetricConfig) + cs.store(group="metric", package="model.config.metric", name="code_usage", node=DiscreteTokenizerMetricConfig) + + +def register_checkpoint(cs): + cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) + + +def register_callback(cs): + cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) + + +def register_configs(): + cs = ConfigStore.instance() + + register_training_data(cs) + register_val_data(cs) + + register_net(cs) + + register_optim(cs) + register_scheduler(cs) + register_loss(cs) + register_metric(cs) + register_checkpoint(cs) + + register_callback(cs) diff --git a/cosmos_predict1/tokenizer/training/datasets/__init__.py b/cosmos_predict1/tokenizer/training/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cosmos_predict1/tokenizer/training/datasets/augmentation_provider.py b/cosmos_predict1/tokenizer/training/datasets/augmentation_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..02166bf140754a4b05a2b883cffea9d5d98e9872 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/augmentation_provider.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Augmentations for tokenizer training (image and video)""" + + +from cosmos_predict1.tokenizer.training.datasets.augmentors import ( + CenterCrop, + CropResizeAugmentor, + HorizontalFlip, + Normalize, + RandomReverse, + ReflectionPadding, + ResizeSmallestSideAspectPreserving, + UnsqueezeImage, +) +from cosmos_predict1.tokenizer.training.datasets.utils import ( + VIDEO_KEY, + VIDEO_RES_SIZE_INFO, + VIDEO_VAL_CROP_SIZE_INFO, + get_crop_size_info, +) +from cosmos_predict1.utils import log +from cosmos_predict1.utils.lazy_config import LazyCall, LazyDict + +_PROB_OF_CROP_ONLY: float = 0.1 + + +def video_train_augmentations( + input_keys: list[str], + resolution: str = "1080", + crop_height: int = 256, +) -> dict[str, LazyDict]: + [_video_key] = input_keys + crop_sizes = get_crop_size_info(crop_height) + log.info(f"[video] training crop_height={crop_height} and crop_sizes: {crop_sizes}.") + augmentations = { + "crop_resize": LazyCall(CropResizeAugmentor)( + input_keys=[_video_key], + output_keys=[VIDEO_KEY], + crop_args={"size": crop_sizes}, + resize_args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + args={"prob": _PROB_OF_CROP_ONLY}, + ), + "random_reverse": LazyCall(RandomReverse)( + input_keys=[VIDEO_KEY], + args={"prob": 0.5}, + ), + "reflection_padding": LazyCall(ReflectionPadding)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "horizontal_flip": LazyCall(HorizontalFlip)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "normalize": LazyCall(Normalize)( + input_keys=[VIDEO_KEY], + args={"mean": 0.5, "std": 0.5}, + ), + "unsqueeze_padding": LazyCall(UnsqueezeImage)(input_keys=["padding_mask"]), + } + + return augmentations + + +def video_val_augmentations( + input_keys: list[str], resolution: str = "1080", crop_height: int = None +) -> dict[str, LazyDict]: + [_video_key] = input_keys + if crop_height is None: + crop_sizes = VIDEO_VAL_CROP_SIZE_INFO[resolution] + else: + crop_sizes = get_crop_size_info(crop_height) + + log.info(f"[video] validation crop_sizes: {crop_sizes}.") + augmenations = { + "resize_smallest_side_aspect_ratio_preserving": LazyCall(ResizeSmallestSideAspectPreserving)( + input_keys=[VIDEO_KEY], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ), + "center_crop": LazyCall(CenterCrop)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "reflection_padding": LazyCall(ReflectionPadding)( + input_keys=[VIDEO_KEY], + args={"size": crop_sizes}, + ), + "normalize": LazyCall(Normalize)( + input_keys=[VIDEO_KEY], + args={"mean": 0.5, "std": 0.5}, + ), + "unsqueeze_padding": LazyCall(UnsqueezeImage)(input_keys=["padding_mask"]), + } + return augmenations diff --git a/cosmos_predict1/tokenizer/training/datasets/augmentors.py b/cosmos_predict1/tokenizer/training/datasets/augmentors.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ed4bea68655b56175324746add2ea3a8fd6e2b --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/augmentors.py @@ -0,0 +1,540 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Additional augmentors for image and video training loops.""" + +from typing import Any, Optional + +import omegaconf +import torch +import torchvision.transforms.functional as transforms_F +from loguru import logger as logging + +from cosmos_predict1.tokenizer.training.datasets.utils import obtain_augmentation_size, obtain_image_size +from cosmos_predict1.utils import log + + +class Augmentor: + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + r"""Base augmentor class + + Args: + input_keys (list): List of input keys + output_keys (list): List of output keys + args (dict): Arguments associated with the augmentation + """ + self.input_keys = input_keys + self.output_keys = output_keys + self.args = args + + def __call__(self, *args: Any, **kwds: Any) -> Any: + raise ValueError("Augmentor not implemented") + + +class LossMask(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + assert self.args is not None, "Please specify args" + mask_config = self.args["masking"] + + input_key = self.input_keys[0] + default_mask = torch.ones_like(data_dict[input_key]) + loss_mask = mask_config["nonhuman_mask"] * default_mask + for curr_key in mask_config: + if curr_key not in self.input_keys: + continue + curr_mask = data_dict[curr_key] + curr_weight = mask_config[curr_key] + curr_loss_mask = curr_mask * curr_weight + (1 - curr_mask) * loss_mask + loss_mask = torch.max(curr_loss_mask, loss_mask) + _ = data_dict.pop(curr_key) + data_dict["loss_mask"] = loss_mask + return data_dict + + +class UnsqueezeImage(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs horizontal flipping. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + for key in self.input_keys: + data_dict[key] = data_dict[key].unsqueeze(1) + + return data_dict + + +class RandomReverse(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random temporal reversing of frames. + + Args: + data_dict (dict): Input data dict, CxTxHxW + Returns: + data_dict (dict): Output dict where videos are randomly reversed. 
+ """ + assert self.args is not None + p = self.args.get("prob", 0.5) + coin_flip = torch.rand(1).item() <= p + for key in self.input_keys: + if coin_flip: + data_dict[key] = torch.flip(data_dict[key], dims=[1]) + + return data_dict + + +class RenameInputKeys(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Rename the input keys from the data dict. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict with keys renamed. + """ + assert len(self.input_keys) == len(self.output_keys) + for input_key, output_key in zip(self.input_keys, self.output_keys): + if input_key in data_dict: + data_dict[output_key] = data_dict.pop(input_key) + return data_dict + + +class CropResizeAugmentor(Augmentor): + def __init__( + self, + input_keys: list, + output_keys: Optional[list] = None, + crop_args: Optional[dict] = None, + resize_args: Optional[dict] = None, + args: Optional[dict] = None, + ) -> None: + super().__init__(input_keys, output_keys, args) + self.crop_args = crop_args + self.resize_args = resize_args + self.crop_op = RandomCrop(input_keys, output_keys, crop_args) + self.resize_op = ResizeSmallestSideAspectPreserving(input_keys, output_keys, resize_args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random temporal reversing of frames. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where videso are randomly reversed. + """ + assert self.args is not None + p = self.args.get("prob", 0.1) + + if p > 0.0: + crop_img_size = obtain_augmentation_size(data_dict, self.crop_args) + crop_width, crop_height = crop_img_size + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + if orig_w < crop_width or orig_h < crop_height: + log.warning( + f"Data size ({orig_w}, {orig_h}) is smaller than crop size ({crop_width}, {crop_height}), skip the crop augmentation." + ) + coin_flip = torch.rand(1).item() <= p + if coin_flip and crop_width <= orig_w and crop_height <= orig_h: + data_dict = self.crop_op(data_dict) + return data_dict + + data_dict = self.resize_op(data_dict) + data_dict = self.crop_op(data_dict) + + return data_dict + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + assert (self.args is not None) and ("size" in self.args), "Please specify size in args" + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + + # We also add the aug params we use. 
This will be useful for other transforms + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"])) + return data_dict + + +class RandomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + # Obtaining random crop coords + try: + crop_x0 = int(torch.randint(0, orig_w - width + 1, size=(1,)).item()) + crop_y0 = int(torch.randint(0, orig_h - height + 1, size=(1,)).item()) + except Exception as e: + logging.warning( + f"Random crop failed. Performing center crop, original_size(wxh): {orig_w}x{orig_h}, random_size(wxh): {width}x{height}" + ) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + + # We also add the aug params we use. This will be useful for other transforms + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + + # We must perform same random cropping for all input keys + for key in self.input_keys: + data_dict[key] = transforms_F.crop(data_dict[key], crop_y0, crop_x0, height, width) + return data_dict + + +class HorizontalFlip(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs horizontal flipping. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + flip_enabled = getattr(self.args, "enabled", True) + if flip_enabled: + p = getattr(self.args, "prob", 0.5) + coin_flip = torch.rand(1).item() > p + for key in self.input_keys: + if coin_flip: + data_dict[key] = transforms_F.hflip(data_dict[key]) + + return data_dict + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. 
+ """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict + + +class ReflectionPadding(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs reflection padding. This function also returns a padding mask. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + + assert self.args is not None, "Please specify args in augmentation" + if self.output_keys is None: + self.output_keys = self.input_keys + + # Obtain image and augmentation sizes + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + target_size = obtain_augmentation_size(data_dict, self.args) + + assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" + target_w, target_h = target_size + + target_w = int(target_w) + target_h = int(target_h) + + # Calculate padding vals + padding_left = int((target_w - orig_w) / 2) + padding_right = target_w - orig_w - padding_left + padding_top = int((target_h - orig_h) / 2) + padding_bottom = target_h - orig_h - padding_top + padding_vals = [padding_left, padding_top, padding_right, padding_bottom] + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: + # In this case, we can't perform reflection padding. This is because padding values + # are larger than the image size. So, perform edge padding instead. + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") + else: + # Perform reflection padding + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") + + if out_key != inp_key: + del data_dict[inp_key] + + # Return padding_mask when padding is performed. + # Padding mask denotes which pixels are padded. 
+ padding_mask = torch.ones((1, target_h, target_w)) + padding_mask[:, padding_top : (padding_top + orig_h), padding_left : (padding_left + orig_w)] = 0 + data_dict["padding_mask"] = padding_mask + data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) + + return data_dict + + +class ResizeSmallestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to smaller side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=out_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to larger side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + + scaling_ratio = min(out_size / orig_w, out_size / orig_h) + target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] + + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. 
+ + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] >= img_h and target_size[1] >= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the larger ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the larger of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance( + img_size, (tuple, omegaconf.listconfig.ListConfig) + ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert ( + target_size[0] <= img_h and target_size[1] <= img_w + ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict diff --git a/cosmos_predict1/tokenizer/training/datasets/dataset_provider.py b/cosmos_predict1/tokenizer/training/datasets/dataset_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..9904bd1eba5d77d3baae9b5ff7b5b6acabba5dae --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/dataset_provider.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of dataset settings and augmentations for tokenization + +Run this command to interactively debug: +python3 -m cosmos_predict1.tokenizer.training.datasets.dataset_provider + +""" + +from cosmos_predict1.tokenizer.training.datasets.augmentation_provider import ( + video_train_augmentations, + video_val_augmentations, +) +from cosmos_predict1.tokenizer.training.datasets.utils import categorize_aspect_and_store +from cosmos_predict1.tokenizer.training.datasets.video_dataset import Dataset +from cosmos_predict1.utils.lazy_config import instantiate + +_VIDEO_PATTERN_DICT = { + "hdvila_video": "datasets/hdvila/videos/*.mp4", +} + + +def apply_augmentations(data_dict, augmentations_dict): + """ + Loop over each LazyCall object and apply it to data_dict in place. + """ + for aug_name, lazy_aug in augmentations_dict.items(): + aug_instance = instantiate(lazy_aug) + data_dict = aug_instance(data_dict) + return data_dict + + +class AugmentDataset(Dataset): + def __init__(self, base_dataset, augmentations_dict): + """ + base_dataset: the video dataset instance + augmentations_dict: the dictionary returned by + video_train_augmentations() or video_val_augmentations() + """ + self.base_dataset = base_dataset + + # Pre-instantiate every augmentation ONCE: + self.augmentations = [] + for aug_name, lazy_aug in augmentations_dict.items(): + aug_instance = instantiate(lazy_aug) # build the actual augmentation + self.augmentations.append((aug_name, aug_instance)) + + def __len__(self): + return len(self.base_dataset) + + def __getitem__(self, index): + # Get the raw sample from the base dataset + data = self.base_dataset[index] + data = categorize_aspect_and_store(data) + + # Apply each pre-instantiated augmentation + for aug_name, aug_instance in self.augmentations: + data = aug_instance(data) + + return data + + +def dataset_entry( + dataset_name: str, + dataset_type: str, + is_train: bool = True, + resolution="720", + crop_height=256, + num_video_frames=25, +) -> AugmentDataset: + if dataset_type != "video": + raise ValueError(f"Dataset type {dataset_type} is not supported") + + # Instantiate the video dataset + base_dataset = Dataset( + video_pattern=_VIDEO_PATTERN_DICT[dataset_name.lower()], + num_video_frames=num_video_frames, + ) + + # Pick the training or validation augmentations + if is_train: + aug_dict = video_train_augmentations( + input_keys=["video"], # adjust if necessary + resolution=resolution, + crop_height=crop_height, + ) + else: + aug_dict = video_val_augmentations( + input_keys=["video"], + resolution=resolution, + crop_height=crop_height, + ) + + # Wrap the dataset with the augmentations + return AugmentDataset(base_dataset, aug_dict) + + +if __name__ == "__main__": + # Example usage / quick test + dataset = dataset_entry( + dataset_name="davis_video", + dataset_type="video", + is_train=False, + resolution="720", + crop_height=256, + num_video_frames=25, + ) + + # 2) Print out some 
basic info: + print(f"Total samples in dataset: {len(dataset)}") + + # 3) Grab one sample (or a few) to check shapes, keys, etc. + if len(dataset) > 0: + sample_idx = 0 + sample = dataset[sample_idx] + print(f"Sample index {sample_idx} keys: {list(sample.keys())}") + if "video" in sample: + print("Video shape:", sample["video"].shape) + if "video_name" in sample: + print("Video metadata:", sample["video_name"]) + print("---\nSample loaded successfully.\n") + else: + print("Dataset has no samples!") diff --git a/cosmos_predict1/tokenizer/training/datasets/mock_dataset.py b/cosmos_predict1/tokenizer/training/datasets/mock_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1e1d5e67b647e1a628c4a2f891b5d95ad07980 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/mock_dataset.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Copied from jam_data. +""" + +import inspect +from typing import Any, Callable, Dict + +import torch +from torch.utils.data import Dataset + +MAX_LENGTH = 1 << 15 + + +class LambdaDataset(torch.utils.data.Dataset): + """ + A dataset that generates items by applying a function. This allows for creating + dynamic datasets where the items are the result of function calls. The function can optionally + accept an index argument. + + Attributes: + length (int): The total number of items in the dataset. + fn (Callable): The function to generate dataset items. + is_index_in_params (bool): Flag to determine whether 'index' should be passed + to the function `fn`. + """ + + def __init__(self, fn: Callable, length: int = MAX_LENGTH) -> None: + """ + Initializes the LambdaDataset with a function and the total length. + + Args: + fn (Callable): A function that returns a dataset item. It can optionally accept an + index argument to generate data items based on their index. + length (int): The total number of items in the dataset, defaults to MAX_LENGTH. + """ + self.length = length + self.fn = fn + + try: + # Attempt to inspect the function signature to determine if it accepts an 'index' parameter. + signature = inspect.signature(fn) + self.is_index_in_params = "index" in signature.parameters + except ValueError: + # If the function signature is not inspectable, assume 'index' is not a parameter. + self.is_index_in_params = False + + def __len__(self) -> int: + """ + Returns the total length of the dataset. + + Returns: + int: The number of items in the dataset. + """ + return self.length + + def __getitem__(self, index: int) -> Any: + """ + Retrieves an item at a specific index from the dataset by calling the function `fn`. + Passes the index to `fn` if `fn` is designed to accept an index. + + Args: + index (int): The index of the item to retrieve. + + Returns: + Any: The item returned by the function `fn`. 
+ """ + if self.is_index_in_params: + return self.fn(index) # Call fn with index if it accepts an index parameter. + return self.fn() # Call fn without any parameters if it does not accept the index. + + +class RepeatDataset(torch.utils.data.Dataset): + """ + A dataset wrapper that allows repeating access to items from an underlying dataset. + + This dataset can be used to create an artificial extension of the underlying dataset + to a specified `length`. Each item from the original dataset can be accessed + repeatedly up to `num_item` times before it loops back. + + Attributes: + length (int): The total length of the dataset to be exposed. + dataset (Dataset): The original dataset. + num_item (int): Number of times each item is repeated. + cache_item (dict): Cache to store accessed items to avoid recomputation. + """ + + def __init__(self, dataset: Dataset, length: int = MAX_LENGTH, num_item: int = 1) -> None: + """ + Initializes the RepeatDataset with a dataset, length, and number of repeats per item. + + Args: + dataset (Dataset): The dataset to repeat. + length (int): The total length of the dataset to be exposed. Defaults to MAX_LENGTH. + num_item (int): The number of times to repeat each item. Defaults to 1. + """ + self.length = length + self.dataset = dataset + self.num_item = num_item + self.cache_item = {} + + def __len__(self) -> int: + return self.length + + def __getitem__(self, index: int) -> Any: + index = index % self.num_item + if index not in self.cache_item: + self.cache_item[index] = self.dataset[index] + return self.cache_item[index] + + +class CombinedDictDataset(torch.utils.data.Dataset): + """ + A dataset that wraps multiple PyTorch datasets and returns a dictionary of data items from each dataset for a given index. + This dataset ensures that all constituent datasets have the same length by setting the length to the minimum length of the datasets provided. + + Parameters: + ----------- + **datasets : Dict[str, Dataset] + A dictionary where keys are string identifiers for the datasets and values are the datasets instances themselves. + + Attributes: + ----------- + datasets : Dict[str, Dataset] + Stores the input datasets. + max_length : int + The minimum length among all provided datasets, determining the length of this combined dataset. + + Examples: + --------- + >>> dataset1 = torch.utils.data.TensorDataset(torch.randn(100, 3, 32, 32)) + >>> dataset2 = torch.utils.data.TensorDataset(torch.randn(100, 3, 32, 32)) + >>> combined_dataset = CombinedDictDataset(dataset1=dataset1, dataset2=dataset2) + >>> print(len(combined_dataset)) + 100 + >>> data = combined_dataset[50] + >>> print(data.keys()) + dict_keys(['dataset1', 'dataset2']) + """ + + def __init__(self, **datasets: Dict[str, Dataset]) -> None: + """ + Initializes the CombinedDictDataset with multiple datasets. + + Args: + **datasets (Dict[str, Dataset]): Key-value pairs where keys are dataset names and values + are dataset instances. Each key-value pair adds a dataset + under the specified key. + """ + self.datasets = datasets + self.max_length = min([len(dataset) for dataset in datasets.values()]) + + def __len__(self) -> int: + return self.max_length + + def __getitem__(self, index: int) -> Dict[str, Any]: + """ + Retrieves an item from each dataset at the specified index, combines them into a dictionary, + and returns the dictionary. Each key in the dictionary corresponds to one of the dataset names provided + during initialization, and its value is the item from that dataset at the given index. 
+ + Args: + index (int): The index of the items to retrieve across all datasets. + + Returns: + Dict[str, Any]: A dictionary containing data items from all datasets for the given index. + Each key corresponds to a dataset name, and its value is the data item from that dataset. + """ + data = {} + for key, dataset in self.datasets.items(): + data[key] = dataset[index] + return data diff --git a/cosmos_predict1/tokenizer/training/datasets/utils.py b/cosmos_predict1/tokenizer/training/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca7c109e778d0329fa9e76b1e8b74b39b4a6972 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/utils.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for datasets creation.""" + +IMAGE_KEY = "images" +VIDEO_KEY = "video" +RECON_KEY = "reconstructions" +LATENT_KEY = "latent" +INPUT_KEY = "INPUT" +MASK_KEY = "loss_mask" + +_SPATIAL_ALIGN = 16 + + +import math +from typing import Union + +import torch +from PIL import Image + +# This is your "for short_side=720" map: +_ASPECT_SIZE_DICT = { + "1,1": (720, 720), + "4,3": (960, 720), + "3,4": (720, 960), + "16,9": (1280, 720), + "9,16": (720, 1280), +} + + +VIDEO_RES_SIZE_INFO: dict[str, tuple[int, int]] = { + "1080": { # 1080p doesn't have 1:1 + "4,3": (1440, 1072), + "3,4": (1072, 1440), + "16,9": (1920, 1072), + "9,16": (1072, 1920), + }, + "720": {"1,1": (720, 720), "4,3": (960, 720), "3,4": (720, 960), "16,9": (1280, 720), "9,16": (720, 1280)}, + "480": {"1,1": (480, 480), "4,3": (640, 480), "3,4": (480, 640), "16,9": (854, 480), "9,16": (480, 854)}, + "512": {"1,1": (512, 512), "4,3": (672, 512), "3,4": (512, 672), "16,9": (896, 512), "9,16": (512, 896)}, + "360": {"1,1": (320, 320), "4,3": (416, 320), "3,4": (320, 416), "16,9": (544, 320), "9,16": (320, 544)}, + "256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 320)}, + "128": { # Note that we set res lower than 256 to the same resolution as 256 + "1,1": (256, 256), + "4,3": (320, 256), + "3,4": (256, 320), + "16,9": (448, 256), + "9,16": (256, 448), + }, +} + +VIDEO_VAL_CROP_SIZE_INFO: dict[str, tuple[int, int]] = { + "1080": { # 1080p doesn't have 1:1 + "4,3": (1424, 1072), + "3,4": (1072, 1424), + "16,9": (1904, 1072), + "9,16": (1072, 1904), + "16,10": (1715, 1072), + }, + "720": {"1,1": (704, 704), "4,3": (944, 704), "3,4": (704, 944), "16,9": (1264, 704), "9,16": (704, 1264)}, + "480": {"1,1": (464, 464), "4,3": (624, 464), "3,4": (464, 624), "16,9": (848, 464), "9,16": (464, 848)}, + "360": {"1,1": (320, 320), "4,3": (416, 320), "3,4": (320, 416), "16,9": (544, 320), "9,16": (320, 544)}, + "512": {"1,1": (512, 512), "4,3": (672, 512), "3,4": (512, 672), "16,9": (896, 512), "9,16": (512, 896)}, + "256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 
320)}, + "128": { # Note that we set res lower than 256 to the same resolution as 256 + "1,1": (256, 256), + "4,3": (320, 256), + "3,4": (256, 320), + "16,9": (320, 192), + "9,16": (192, 320), + "16,10": (410, 256), + }, +} + + +def _pick_closest_aspect_ratio(height, width): + """ + Given a video's height and width, return the closest aspect ratio key + from aspect_dict. + """ + if height == 0: + return "1,1" # fallback if something weird, to avoid div by zero + + actual_ratio = width / height + + best_key = None + min_diff = math.inf + + for ratio_key, (w_target, h_target) in _ASPECT_SIZE_DICT.items(): + # for "16,9" -> (1280, 720), ratio is 1280/720 = 1.7777... + ratio = w_target / h_target + diff = abs(actual_ratio - ratio) + if diff < min_diff: + min_diff = diff + best_key = ratio_key + + return best_key + + +def categorize_aspect_and_store(data_sample): + """ + data_sample: a dict with 'video' shaped [C,T,H,W]. + We will determine the aspect ratio, pick the closest "1,1", "4,3", etc., + and store a new dict entry. + """ + # Suppose 'video' is [C, T, H, W]. + video_tensor = data_sample["video"] + H = video_tensor.shape[-2] + W = video_tensor.shape[-1] + data_sample["aspect_ratio"] = _pick_closest_aspect_ratio(H, W) + return data_sample + + +def get_crop_size_info(crop_sz: int = 128): + aspect_ratios = [(1, 1), (4, 3), (3, 4), (16, 9), (9, 16)] + crop_sizes = dict() + for aspect_ratio in aspect_ratios: + if aspect_ratio[0] < aspect_ratio[1]: + crop_h = crop_sz // _SPATIAL_ALIGN * _SPATIAL_ALIGN + crop_w = int(crop_h * aspect_ratio[0] / aspect_ratio[1] + 0.5) + crop_w = crop_w // _SPATIAL_ALIGN * _SPATIAL_ALIGN + else: + crop_w = crop_sz // _SPATIAL_ALIGN * _SPATIAL_ALIGN + crop_h = int(crop_w * aspect_ratio[1] / aspect_ratio[0] + 0.5) + crop_h = crop_h // _SPATIAL_ALIGN * _SPATIAL_ALIGN + key = f"{aspect_ratio[0]},{aspect_ratio[1]}" + crop_sizes.update({key: (crop_w, crop_h)}) + return crop_sizes + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]: + r"""Function for obtaining size of the augmentation. + When dealing with multi-aspect ratio dataloaders, we need to + find the augmentation size from the aspect ratio of the data. 
+ + Args: + data_dict (dict): Input data dict + augmentor_cfg (dict): Augmentor config + Returns: + aug_size (int): Size of augmentation + """ + if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: + aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] + aug_size = augmentor_cfg["size"][aspect_ratio] + else: # Non-webdataset format + aspect_ratio = data_dict["aspect_ratio"] + aug_size = augmentor_cfg["size"][aspect_ratio] + return aug_size diff --git a/cosmos_predict1/tokenizer/training/datasets/video_dataset.py b/cosmos_predict1/tokenizer/training/datasets/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..169a9bc92fd598f9c586dbcf62f1fc8635dedf37 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/datasets/video_dataset.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run this command to interactively debug: +PYTHONPATH=. python cosmos_predict1/tokenizer/training/datasets/video_dataset.py + +Adapted from: +https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py +""" + +import traceback +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed +from glob import glob + +import numpy as np +import torch +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + +from cosmos_predict1.diffusion.training.datasets.dataset_utils import ToTensorVideo + + +class Dataset(Dataset): + def __init__( + self, + video_pattern, + sequence_interval=1, + start_frame_interval=1, + num_video_frames=25, + ): + """Dataset class for loading image-text-to-video generation data. 
+ + Args: + video_pattern (str): path/to/videos/*.mp4 + sequence_interval (int): Interval between sampled frames in a sequence + num_frames (int): Number of frames to load per sequence + video_size (list): Target size [H,W] for video frames + + Returns dict with: + - video: RGB frames tensor [T,C,H,W] + - video_name: Dict with episode/frame metadata + """ + + super().__init__() + self.video_directory_or_pattern = video_pattern + self.start_frame_interval = start_frame_interval + self.sequence_interval = sequence_interval + self.sequence_length = num_video_frames + + self.video_paths = sorted(glob(str(video_pattern))) + print(f"{len(self.video_paths)} videos in total") + + self.samples = self._init_samples(self.video_paths) + self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0])) + print(f"{len(self.samples)} samples in total") + self.wrong_number = 0 + self.preprocess = T.Compose( + [ + ToTensorVideo(), + ] + ) + + def __str__(self): + return f"{len(self.video_paths)} samples from {self.video_directory_or_pattern}" + + def _init_samples(self, video_paths): + samples = [] + with ThreadPoolExecutor(32) as executor: + future_to_video_path = { + executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths + } + for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)): + samples.extend(future.result()) + return samples + + def _load_and_process_video_path(self, video_path): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + n_frames = len(vr) + + samples = [] + for frame_i in range(0, n_frames, self.start_frame_interval): + sample = dict() + sample["video_path"] = video_path + sample["frame_ids"] = [] + curr_frame_i = frame_i + while True: + if curr_frame_i > (n_frames - 1): + break + sample["frame_ids"].append(curr_frame_i) + if len(sample["frame_ids"]) == self.sequence_length: + break + curr_frame_i += self.sequence_interval + # make sure there are sequence_length number of frames + if len(sample["frame_ids"]) == self.sequence_length: + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def _load_video(self, video_path, frame_ids): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) + assert (np.array(frame_ids) < len(vr)).all() + assert (np.array(frame_ids) >= 0).all() + vr.seek(0) + frame_data = vr.get_batch(frame_ids).asnumpy() + return frame_data + + def _get_frames(self, video_path, frame_ids): + frames = self._load_video(video_path, frame_ids) + frames = frames.astype(np.uint8) + frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # (l, c, h, data) + frames = self.preprocess(frames) + frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8) + return frames + + def __getitem__(self, index): + try: + sample = self.samples[index] + video_path = sample["video_path"] + frame_ids = sample["frame_ids"] + + data = dict() + + video = self._get_frames(video_path, frame_ids) + video = video.permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W] + data["video"] = video + data["video_name"] = { + "video_path": video_path, + "start_frame_id": str(frame_ids[0]), + } + data["fps"] = 24 + data["image_size"] = torch.tensor([704, 1280, 704, 1280]) # .cuda() # TODO: Does this matter? + data["num_frames"] = self.sequence_length + data["padding_mask"] = torch.zeros(1, 704, 1280) # .cuda() + + return data + except Exception: + warnings.warn( + f"Invalid data encountered: {self.samples[index]['video_path']}. 
Skipped " + f"(by randomly sampling another sample in the same dataset)." + ) + warnings.warn("FULL TRACEBACK:") + warnings.warn(traceback.format_exc()) + self.wrong_number += 1 + print(self.wrong_number) + return self[np.random.randint(len(self.samples))] + + +if __name__ == "__main__": + dataset = Dataset( + video_directory_or_pattern="assets/example_training_data/videos/*.mp4", + sequence_interval=1, + num_frames=57, + video_size=[240, 360], + ) + + indices = [0, 13, 200, -1] + for idx in indices: + data = dataset[idx] + print((f"{idx=} " f"{data['video'].sum()=}\n" f"{data['video'].shape=}\n" f"{data['video_name']=}\n" "---")) diff --git a/cosmos_predict1/tokenizer/training/jit_cli.py b/cosmos_predict1/tokenizer/training/jit_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..6520a2afbb7ff2c4b024766cab652b4475f5bab4 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/jit_cli.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A CLI to export an pre-trained tokenizer checkpoint into a torch.ScriptModule. + +Usage: +python3 -m cosmos_predict1.tokenizer.training.jit_cli \ + --ckpt_path=checkpoints/Cosmos-0.1-Tokenizer-CV4x8x8/iter_001000000.pt \ + --output_dir=checkpoints/Cosmos-0.1-Tokenizer-CV4x8x8/exported \ + --strict_resume \ + --config=cosmos_predict1/tokenizer/training/configs/config.py -- \ + experiment=CV720_Causal_AE49_4x8x8_cosmos + + + will output: + /iter_001000000_ema.jit + /iter_001000000_enc.jit + /iter_001000000_dec.jit + + if --reg is specified, it will export the regular model: + /iter_001000000_reg.jit + /iter_001000000_enc.jit + /iter_001000000_dec.jit +""" + +import argparse +import importlib +import os + +import torch +from loguru import logger as logging +from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule + +from cosmos_predict1.tokenizer.training.checkpointer import TokenizerCheckpointer +from cosmos_predict1.utils import callback, ema +from cosmos_predict1.utils.config import Config +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.model import Model + +parser = argparse.ArgumentParser(description="Export a pre-trained model into a torch.jit.ScriptModule.") +parser.add_argument( + "--config", type=str, default="cosmos_predict1/tokenizer/training/configs/config.py", help="Path to the config file" +) +parser.add_argument("--ckpt_path", type=str, default=None, help="The full ckpt path.") +parser.add_argument("--credentials", type=str, default="credentials/pdx_vfm_base.secret", help="The credentials file.") +parser.add_argument("--strict_resume", action="store_true", help="Enable strictly loading into every network weight.") +parser.add_argument("--reg", action="store_true", help="Enable regular model export.") +parser.add_argument("--output_dir", 
type=str, default=None, help="Optional output directory.") + +parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, +) + +logging.info("Initialize args, cfg from command line arguments ...") +args = parser.parse_args() +config_module = get_config_module(args.config) +config: Config = importlib.import_module(config_module).make_config() +config = override(config, args.opts) + + +def _compile_jit_models(model: Model) -> dict[str, torch.ScriptModule]: + """Returns a TorchScript version of REG or EMA models compiled by PyTorch JIT.""" + assert hasattr(config, "checkpoint") and hasattr(config.checkpoint, "jit") + config_jit = config.checkpoint.jit + input_shape = tuple(config_jit.input_shape) + example_input = torch.randn(input_shape) + dtype = getattr(torch, config_jit.dtype) + example_input = example_input.to(config_jit.device).to(dtype) + + # Make sure jit model output consistenly during consecutive calls + # Check here: https://github.com/pytorch/pytorch/issues/74534 + torch._C._jit_set_texpr_fuser_enabled(False) + + with ema.ema_scope(model, enabled=model.config.ema.enabled and not args.reg): + _model = model.network.eval() + if isinstance(_model, torch_OptimizedModule): + _model = _model._orig_mod + model_jit = torch.jit.trace(_model, example_input, strict=config_jit.strict) + encoder_jit = torch.jit.trace(_model.encoder_jit(), example_input, strict=config_jit.strict) + decoder_example = encoder_jit(example_input)[0] + decoder_jit = torch.jit.trace(_model.decoder_jit(), decoder_example, strict=config_jit.strict) + if args.reg: + return {"reg": model_jit, "enc": encoder_jit, "dec": decoder_jit} + return {"ema": model_jit, "enc": encoder_jit, "dec": decoder_jit} + + +def _run_export() -> None: + """Exports a torch.nn.Module into a torch.jit.ScriptModule.""" + # Check that the config is valid. + config.validate() + config.checkpoint.load_path = args.ckpt_path + config.checkpoint.strict_resume = args.strict_resume + config.checkpoint.load_training_state = False + config.job.name = os.path.basename(args.output_dir) if args.output_dir else os.path.basename(args.ckpt_path) + + # Freeze the config. + config.freeze() # type: ignore + callbacks = callback.CallBackGroup(config=config, trainer=None) + checkpointer = TokenizerCheckpointer(config.checkpoint, config.job, callbacks=callbacks) + + # Create the model. + logging.info(f"Instantiate model={config.model.config.network.name} ...") + model = instantiate(config.model) + model = model.to("cuda", memory_format=config.trainer.memory_format) # type: ignore + model.on_train_start(config.trainer.memory_format) + + logging.info(f"loading weights from {config.checkpoint.load_path}...") + _ = checkpointer.load(model) + model.eval() + ckpt_name = config.checkpoint.load_path.split("/")[-1][:-3] + + # Drive the output directory. 
+ tmp_output_dir = os.path.dirname(config.checkpoint.load_path) + output_dir = args.output_dir or tmp_output_dir + os.makedirs(output_dir, exist_ok=True) + + logging.info("Performing JIT compilation ...") + jit_models = _compile_jit_models(model) + for name, jit_model in jit_models.items(): + logging.info(f"Outputing torch.jit: {output_dir}/{ckpt_name}_{name}.jit") + torch.jit.save(jit_model, f"{output_dir}/{ckpt_name}_{name}.jit") + + +@logging.catch(reraise=True) +def main() -> None: + _run_export() + + +if __name__ == "__main__": + main() diff --git a/cosmos_predict1/tokenizer/training/losses/__init__.py b/cosmos_predict1/tokenizer/training/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f946d23712a17a93be7fc9668985dc7cf070b01c --- /dev/null +++ b/cosmos_predict1/tokenizer/training/losses/__init__.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The loss reduction modes.""" + +from enum import Enum + +import torch + + +def _mean(recon: torch.Tensor) -> torch.Tensor: + return torch.mean(recon) + + +def _sum_per_frame(recon: torch.Tensor) -> torch.Tensor: + batch_size = recon.shape[0] * recon.shape[2] if recon.ndim == 5 else recon.shape[0] + return torch.sum(recon) / batch_size + + +def _sum(recon: torch.Tensor) -> torch.Tensor: + return torch.sum(recon) / recon.shape[0] + + +class ReduceMode(Enum): + MEAN = "MEAN" + SUM_PER_FRAME = "SUM_PER_FRAME" + SUM = "SUM" + + @property + def function(self): + if self == ReduceMode.MEAN: + return _mean + elif self == ReduceMode.SUM_PER_FRAME: + return _sum_per_frame + elif self == ReduceMode.SUM: + return _sum + else: + raise ValueError("Invalid ReduceMode") diff --git a/cosmos_predict1/tokenizer/training/losses/continuous.py b/cosmos_predict1/tokenizer/training/losses/continuous.py new file mode 100644 index 0000000000000000000000000000000000000000..86178374dd737d61f5f0783bf5e781cc261a1734 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/losses/continuous.py @@ -0,0 +1,479 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
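Editorial aside on the ReduceMode helpers defined in cosmos_predict1/tokenizer/training/losses/__init__.py above: the short sketch below (not part of this diff) shows how the three reductions differ for a per-pixel video loss of shape [B, C, T, H, W]; the tensor shape and values are made up for illustration.

import torch

from cosmos_predict1.tokenizer.training.losses import ReduceMode

# Made-up per-pixel reconstruction loss, e.g. |x - x_hat|, with B=2, C=3, T=8, H=W=32.
recon = torch.rand(2, 3, 8, 32, 32)

mean_loss = ReduceMode("MEAN").function(recon)            # mean over every element
per_frame = ReduceMode("SUM_PER_FRAME").function(recon)   # sum / (B * T) for 5-D inputs
per_sample = ReduceMode("SUM").function(recon)            # sum / B
print(mean_loss.item(), per_frame.item(), per_sample.item())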
+ +"""The combined loss functions for continuous-space tokenizers training.""" +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import torchvision.models.optical_flow as optical_flow + +from cosmos_predict1.tokenizer.modules.utils import batch2time, time2batch +from cosmos_predict1.tokenizer.training.datasets.utils import INPUT_KEY, LATENT_KEY, MASK_KEY, RECON_KEY +from cosmos_predict1.tokenizer.training.losses import ReduceMode +from cosmos_predict1.tokenizer.training.losses.lpips import LPIPS +from cosmos_predict1.utils.lazy_config import instantiate + +_VALID_LOSS_NAMES = ["color", "perceptual", "flow", "kl", "video_consistency"] +VIDEO_CONSISTENCY_LOSS = "video_consistency" +RECON_CONSISTENCY_KEY = f"{RECON_KEY}_consistency" + + +class TokenizerLoss(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + _reduce = ReduceMode(config.reduce.upper()) if hasattr(config, "reduce") else None + self.reduce = _reduce.function + self.loss_modules = nn.ModuleDict() + for key in _VALID_LOSS_NAMES: + self.loss_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NullLoss() + + def forward(self, inputs, output_batch, iteration) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + loss = dict() + total_loss = 0.0 + + inputs[MASK_KEY] = torch.ones_like(inputs[INPUT_KEY]) + # Calculates reconstruction losses (`total_loss`). + for key, module in self.loss_modules.items(): + curr_loss = module(inputs, output_batch, iteration) + loss.update({k: torch.mean(v) for k, v in curr_loss.items()}) + total_loss += sum([self.reduce(v) if (v.dim() > 0) else v for v in curr_loss.values()]) + + loss.update({k: torch.mean(v) for k, v in curr_loss.items()}) + + # Computes the overall loss as sum of the reconstruction losses and the generator loss. 
+ total_loss += sum([self.reduce(v) if (v.dim() > 0) else v for v in curr_loss.values()]) + return dict(loss=loss), total_loss + + +class WeightScheduler(torch.nn.Module): + def __init__(self, boundaries, values): + super().__init__() + self.boundaries = list(boundaries) + self.values = list(values) + + def forward(self, iteration): + for boundary, value in zip(self.boundaries, self.values): + if iteration < boundary: + return value + return self.values[-1] + + +class NullLoss(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, inputs, output_batch, iteration) -> dict[dict, torch.Tensor]: + return dict() + + +class ColorLoss(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) + + def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: + reconstructions = output_batch[RECON_KEY] + weights = inputs[MASK_KEY] + recon = weights * torch.abs(inputs[INPUT_KEY].contiguous() - reconstructions.contiguous()) + color_weighted = self.schedule(iteration) * recon + if torch.isnan(color_weighted).any(): + raise ValueError("[COLOR] NaN detected in loss") + return dict(color=color_weighted) + + +class KLLoss(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) + + def kl(self, mean, logvar): + _dims = [idx for idx in range(1, mean.ndim)] + var = torch.exp(logvar) + return 0.5 * (torch.pow(mean, 2) + var - 1.0 - logvar) + + def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: + if "posteriors" not in output_batch: # No KL loss for discrete tokens. + return dict() + mean, logvar = output_batch["posteriors"] + if mean.ndim == 1: # No KL if the mean is a scalar. + return dict() + kl = self.kl(mean, logvar) + kl_weighted = self.schedule(iteration) * kl + if torch.isnan(kl_weighted).any(): + raise ValueError("[KL] NaN detected in loss") + return dict(kl=kl_weighted) + + +class PerceptualLoss(LPIPS): + """Relevant changes that're internal to us: + + - Remove linear projection layers, simply use the raw pre-normalized features. + - Use pyramid-layer weights: [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]. + - Accepts pixel-wise masks and modulates the features before norm calculation. + - Implements gram-matrix and correlation losses. 
+ """ + + def __init__(self, config): + super(PerceptualLoss, self).__init__(config.checkpoint_activations) + self.net = self.net.eval() + self.gram_enabled = config.gram_enabled + self.corr_enabled = config.corr_enabled + self.layer_weights = list(config.layer_weights) + self.lpips_schedule = WeightScheduler(config.lpips_boundaries, config.lpips_values) + self.gram_schedule = WeightScheduler(config.gram_boundaries, config.gram_values) + self.corr_schedule = WeightScheduler(config.corr_boundaries, config.corr_values) + self.checkpoint_activations = config.checkpoint_activations + + def _temporal_gram_matrix(self, x, batch_size=None): + x = batch2time(x, batch_size) + c, t, h, w = x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1] + reshaped_x = torch.reshape(x, [-1, c, t * h * w]) + return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(t * h * w) + + def _gram_matrix(self, x, batch_size=None): + if batch_size is not None and x.shape[0] != batch_size: + return self._temporal_gram_matrix(x, batch_size) + c, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + reshaped_x = torch.reshape(x, [-1, c, h * w]) + return torch.matmul(reshaped_x, reshaped_x.transpose(1, 2)) / float(h * w) + + def forward(self, inputs, output_batch, iteration): + output_dict = dict() + reconstructions = output_batch[RECON_KEY] + weights = inputs[MASK_KEY] + input_images = inputs[INPUT_KEY] + + if input_images.ndim == 5: + input_images, batch_size = time2batch(input_images) + reconstructions, _ = time2batch(reconstructions) + weights, _ = time2batch(weights) + else: + batch_size = input_images.shape[0] + + in0_input, in1_input = (self.scaling_layer(input_images), self.scaling_layer(reconstructions)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + + _layer_weights = self.layer_weights + weights_map, res, diffs = {}, {}, {} + for kk in range(len(self.chns)): + weights_map[kk] = torch.nn.functional.interpolate(weights[:, :1, ...], outs0[kk].shape[-2:]) + diffs[kk] = weights_map[kk] * torch.abs(outs0[kk] - outs1[kk]) + res[kk] = _layer_weights[kk] * diffs[kk].mean([1, 2, 3], keepdim=True) + + val = res[0] + for ll in range(1, len(self.chns)): + val += res[ll] + # Scale by number of pixels to match pixel-wise losses. + val = val.expand(-1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1]) + if batch_size != input_images.shape[0]: + val = batch2time(val, batch_size) + if torch.isnan(val).any(): + raise ValueError("[LPIPS] NaN detected in loss") + output_dict["lpips"] = self.lpips_schedule(iteration) * val + + if self.gram_enabled and self.gram_schedule(iteration) > 0.0: + num_chans = len(self.chns) + grams0 = [self._gram_matrix(weights_map[kk] * outs0[kk], batch_size) for kk in range(num_chans)] + grams1 = [self._gram_matrix(weights_map[kk] * outs1[kk], batch_size) for kk in range(num_chans)] + gram_diffs = [(grams0[kk] - grams1[kk]) ** 2 for kk in range(num_chans)] + grams_res = [_layer_weights[kk] * gram_diffs[kk].mean([1, 2], keepdim=True) for kk in range(num_chans)] + gram_val = grams_res[0] + for ll in range(1, len(self.chns)): + gram_val += grams_res[ll] + + # Scale by number of total pixels to match pixel-wise losses. 
+ gram_val = gram_val.unsqueeze(1).expand( + -1, input_images.shape[-3], input_images.shape[-2], input_images.shape[-1] + ) + if batch_size != input_images.shape[0]: + gram_val = batch2time(gram_val, batch_size) + if torch.isnan(gram_val).any(): + raise ValueError("[GRAM] NaN detected in loss") + output_dict["gram"] = self.gram_schedule(iteration) * gram_val + return output_dict + + def torch_compile(self): + """ + This method invokes torch.compile() on this loss + """ + # cuda-graphs crash after 1k iterations + self.net = torch.compile(self.net, dynamic=False) + + +class FlowLoss(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.schedule = WeightScheduler(config.boundaries, config.values) + self.scale = config.scale + self.dtype = getattr(torch, config.dtype) + self.checkpoint_activations = config.checkpoint_activations + self.enabled = config.enabled + + current_device = torch.device(torch.cuda.current_device()) + + # In order to be able to run model in bf16 we need to change make_coords_grid() + # to allow it to return arbitrary type provided by us in argument + # the line from orginal implementation that caused results to be only fp32 is commented + # Additionally I've changed that function to run on GPU instead of CPU, which results in + # less graph breaks when torch.compile() is used + # This function is copied from + # https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/_utils.py#L22 + # commit: b06ea39d5f0adbe949d08257837bda912339e415 + def make_coords_grid( + batch_size: int, h: int, w: int, device: torch.device = current_device, dtype: torch.dtype = self.dtype + ): + # Original: def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"): + device = torch.device(device) + coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij") + coords = torch.stack(coords[::-1], dim=0).to(dtype) + # Original: coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch_size, 1, 1, 1) + + # We also need to specify output dtype of torch.linspace() in index_pyramid() + # method of CorrBlock, otherwise it uses default fp32 dtype as output. 
+ # Additionally I've changed that function to run on GPU instead of CPU, which results in + # less graph breaks when torch.compile() is used + # This function is copied from + # https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py#L394 + # commit: b06ea39d5f0adbe949d08257837bda912339e415 + def index_pyramid( + self, centroids_coords, dtype: torch.dtype = self.dtype, device: torch.device = current_device + ): + # Original: def index_pyramid(self, centroids_coords): + """Return correlation features by indexing from the pyramid.""" + neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels + di = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device) + dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len, dtype=dtype, device=device) + # Original: di = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + # Original: dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device) + delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2) + + batch_size, _, h, w = centroids_coords.shape # _ = 2 + centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2) + + indexed_pyramid = [] + for corr_volume in self.corr_pyramid: + sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2) + indexed_corr_volume = optical_flow.raft.grid_sample( + corr_volume, sampling_coords, align_corners=True, mode="bilinear" + ).view(batch_size, h, w, -1) + indexed_pyramid.append(indexed_corr_volume) + centroids_coords = centroids_coords / 2 + + corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous() + + expected_output_shape = (batch_size, self.out_channels, h, w) + if corr_features.shape != expected_output_shape: + raise ValueError( + f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}" + ) + + return corr_features + + optical_flow.raft.make_coords_grid = make_coords_grid + optical_flow.raft.CorrBlock.index_pyramid = index_pyramid + + flow_model = optical_flow.raft_large(pretrained=True, progress=False) + flow_model.requires_grad_(False) + flow_model.eval() + flow_model = flow_model.to(self.dtype) + + self.flow_model = flow_model + + def _run_model(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: + """Runs flow_model in the forward mode on explicit dtype=float32. + + Args: + input1: First video frames batch, layout (T, C, H, W), bfloat16. + input2: Next video frames batch, layout (T, C, H, W), bfloat16. + + Returns: + Forward optical flow, (T, 2, H, W), bfloat16. + """ + input_dtype = input1.dtype + flow_output = self.flow_model.to(self.dtype)(input1.to(self.dtype), input2.to(self.dtype))[-1] + return flow_output.to(input_dtype) + + def _run_model_fwd(self, input_video: torch.Tensor) -> torch.Tensor: + """Runs foward flow on a batch of videos, one batch at a time. + Args: + input_video: The input batch of videos, layout (B, T, C, H, W). + + Returns: + Forward optical flow, layout (B, 2, T-1, H, W). 
+ """ + output_list = list() + for fwd_input_frames in input_video: + fwd_input_frames = fwd_input_frames.transpose(1, 0) + fwd_flow_output = self._run_model(fwd_input_frames[:-1], fwd_input_frames[1:]) + output_list.append(fwd_flow_output.transpose(1, 0)) + return torch.stack(output_list, dim=0) + + def _bidirectional_flow(self, input_video: torch.Tensor) -> torch.Tensor: + """The bidirectional optical flow on a batch of videos. + + The forward and backward flows are averaged to get the bidirectional flow. + To reduce memory pressure, the input video is scaled down by a factor of `self.scale`, + and rescaled back to match other pixel-wise losses. + + Args: + input_video: The input batch of videos, layout (B, T, C, H, W). + + Returns: + Biderectinoal flow, layout (B, 2, T-1, H, W). + """ + # scale down the input video to reduce memory pressure. + t, h, w = input_video.shape[-3:] + input_video_scaled = F.interpolate(input_video, (t, h // self.scale, w // self.scale), mode="trilinear") + + # forward flow. + if self.checkpoint_activations: + fwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False) + else: + fwd_flow_output = self._run_model_fwd(input_video_scaled) + + # backward flow. + input_video_scaled = input_video_scaled.flip([2]) + if self.checkpoint_activations: + bwd_flow_output = checkpoint.checkpoint(self._run_model_fwd, input_video_scaled, use_reentrant=False) + else: + bwd_flow_output = self._run_model_fwd(input_video_scaled) + bwd_flow_output = bwd_flow_output.flip([2]) + + # bidirectional flow, concat fwd and bwd along temporal axis. + flow_input = torch.cat([fwd_flow_output, bwd_flow_output], dim=2) + return self.scale * F.interpolate(flow_input, (2 * (t - 1), h, w), mode="trilinear") + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + input_images = inputs[INPUT_KEY] + if input_images.ndim == 4 or input_images.shape[2] == 1: + return dict() + if not self.enabled or self.schedule(iteration) == 0.0: + return dict() + + # Biderectional flow (B, 2, 2*(T-1), H, W) + flow_input = self._bidirectional_flow(input_images) + flow_recon = self._bidirectional_flow(output_batch[RECON_KEY]) + + # L1 loss on the flow. (B, 1, 2*(T-1), H, W) + flow_loss = torch.abs(flow_input - flow_recon).mean(dim=1, keepdim=True) + + flow_loss_weighted = self.schedule(iteration) * flow_loss + if torch.isnan(flow_loss_weighted).any(): + raise ValueError("[FLOW] NaN detected in loss") + return dict(flow=flow_loss_weighted) + + def torch_compile(self): + """ + This method invokes torch.compile() on this loss + """ + self.flow_model = torch.compile(self.flow_model, dynamic=False) + + +class VideoConsistencyLoss(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.schedule = WeightScheduler(boundaries=config.boundaries, values=config.values) + self.enabled = config.enabled + self.num_frames = config.num_frames + self.step = config.step + self.num_windows = None + + def shuffle(self, inputs: torch.Tensor) -> torch.Tensor: + """ + For input video of [B, 3, T, H, W], this function will reshape the video to + the shape of [B*(T-num_frames+1)//step, 3, num_frames, H, W] using a sliding window + This function is used to compute the temporal consistency between overlapped frames + to enable temporal consistency + """ + assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. 
currently {inputs.shape}" + B, C, T, H, W = inputs.shape + assert T >= self.num_frames, f"inputs {T} should be greater than {self.num_frames}" + + # [B, C, num_windows, H, W, num_frames] + outputs = inputs.unfold(dimension=2, size=self.num_frames, step=self.step) + self.num_windows = outputs.shape[2] + outputs = einops.rearrange(outputs, "b c m h w n -> (b m) c n h w") + + return outputs + + def forward(self, inputs, output_batch, iteration) -> dict[str, torch.Tensor]: + if not self.enabled or self.num_windows is None: + return dict() + if self.schedule(iteration) == 0.0: + return dict() + # reshape output_batch to compute loss between overlapped frames + reconstructions = output_batch[RECON_CONSISTENCY_KEY] + B, C, T, H, W = reconstructions.shape + + assert T == self.num_frames, f"reconstruction shape invalid (shape[2] should be {self.num_frames})" + assert ( + B % self.num_windows == 0 + ), f"reconstruction shape invalid (shape[0]={B} not dividable by {self.num_windows})" + + B = B // self.num_windows + videos = reconstructions.view(B, self.num_windows, C, self.num_frames, H, W) + + # Compute the L1 distance between overlapped frames for all windows at once + diff = torch.mean(torch.abs(videos[:, :-1, :, self.step :, :, :] - videos[:, 1:, :, : -self.step, :, :])) + diff_weighted = self.schedule(iteration) * diff + + if LATENT_KEY not in output_batch: + return dict(frame_consistency=diff_weighted) + + B_latent, C_latent, T_latent, H_latent, W_latent = output_batch["latent"].shape + assert B_latent % self.num_windows == 0, f"latent batches should be divisible by {self.num_windows}" + + latents = output_batch[LATENT_KEY].view( + B_latent // self.num_windows, self.num_windows, C_latent, T_latent, H_latent, W_latent + ) + temporal_rate = self.num_frames // T_latent + spatial_rate = (H // H_latent) * (W // W_latent) + step_latent = self.step // temporal_rate + latent_diff = torch.mean( + torch.abs(latents[:, :-1, :, step_latent:, :, :] - latents[:, 1:, :, :-step_latent, :, :]) + ) + latent_diff_weighted = self.schedule(iteration) * latent_diff * (C * temporal_rate * spatial_rate) / (C_latent) + return dict(frame_consistency=diff_weighted, latent_consistency=latent_diff_weighted) + + def unshuffle(self, inputs: torch.Tensor) -> torch.Tensor: + """ + For input video of [B*num_windows, 3, num_frames, H, W], this function will + undo the shuffle to a tensor of shape [B, 3, T, H, W] + """ + assert len(inputs.shape) == 5, f"inputs shape should be [B, 3, T, H, W]. 
currently {inputs.shape}" + B, C, T, H, W = inputs.shape + assert T == self.num_frames, f"inputs shape invalid (shape[2] should be {self.num_frames})" + assert B % self.num_windows == 0, f"inputs shape invalid (shape[0]={B} not dividable by {self.num_windows})" + + B = B // self.num_windows + videos = inputs.view(B, self.num_windows, C, self.num_frames, H, W) + + T = self.num_frames + (self.num_windows - 1) * self.step + current_device = torch.device(torch.cuda.current_device()) + outputs = torch.zeros(B, C, T, H, W).to(inputs.dtype).to(current_device) + counter = torch.zeros_like(outputs) + for i in range(self.num_windows): + outputs[:, :, i * self.step : i * self.step + self.num_frames, :, :] += videos[:, i, :, :, :, :] + counter[:, :, i * self.step : i * self.step + self.num_frames, :, :] += 1 + outputs = outputs / counter + + return outputs diff --git a/cosmos_predict1/tokenizer/training/losses/lpips.py b/cosmos_predict1/tokenizer/training/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..b866177d11d6eef033529b32679a087dbfe102ab --- /dev/null +++ b/cosmos_predict1/tokenizer/training/losses/lpips.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LPIPS loss. + +Adapted from: github.com/CompVis/stable-diffusion/ldm/modules/losses/contperceptual.py. 
+""" + +import hashlib +import os +from collections import namedtuple + +import requests +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from loguru import logger as logging +from torchvision import models +from tqdm import tqdm + +from cosmos_predict1.utils.distributed import is_rank0 + +_TORCH_HOME = os.getenv("TORCH_HOME", "/mnt/workspace/.cache/torch") +_URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} +_CKPT_MAP = {"vgg_lpips": "vgg.pth"} +_MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def _download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def _md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def _get_ckpt_path(name, root, check=False): + assert name in _URL_MAP + path = os.path.join(root, _CKPT_MAP[name]) + if not os.path.exists(path) or (check and not _md5_hash(path) == _MD5_MAP[name]): + logging.info("Downloading {} model from {} to {}".format(name, _URL_MAP[name], path)) + _download(_URL_MAP[name], path) + md5 = _md5_hash(path) + assert md5 == _MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + def __init__(self, checkpoint_activations: bool = False): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False, checkpoint_activations=checkpoint_activations) + + if dist.is_initialized() and not is_rank0(): + dist.barrier() + self.load_from_pretrained() + if dist.is_initialized() and is_rank0(): + dist.barrier() + + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = _get_ckpt_path(name, f"{_TORCH_HOME}/hub/checkpoints") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + logging.info("Loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = _get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [diffs[kk].mean([1, 2, 3], keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None], persistent=False) + self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None], persistent=False) + + def forward(self, inp): + return (inp - self.shift) / 
self.scale + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, checkpoint_activations: bool = False): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.checkpoint_activations = checkpoint_activations + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice1, X, use_reentrant=False) + else: + h = self.slice1(X) + h_relu1_2 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice2, h, use_reentrant=False) + else: + h = self.slice2(h) + h_relu2_2 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice3, h, use_reentrant=False) + else: + h = self.slice3(h) + h_relu3_3 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice4, h, use_reentrant=False) + else: + h = self.slice4(h) + h_relu4_3 = h + + if self.checkpoint_activations: + h = checkpoint.checkpoint(self.slice5, h, use_reentrant=False) + else: + h = self.slice5(h) + h_relu5_3 = h + + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out diff --git a/cosmos_predict1/tokenizer/training/metrics.py b/cosmos_predict1/tokenizer/training/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..c541999a1de5cb44cf1337933b9827d473a84ab3 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/metrics.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
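Editorial aside, a minimal usage sketch for the LPIPS module from lpips.py above (not part of this diff). It assumes the pretrained VGG-LPIPS weights can be fetched by the download helper shown above; shapes and values are made up, and inputs are in [-1, 1] as elsewhere in this codebase.

import torch

from cosmos_predict1.tokenizer.training.losses.lpips import LPIPS

# Loads (and on first use downloads) the pretrained VGG-LPIPS weights, see _get_ckpt_path above.
lpips = LPIPS().eval()

x = torch.rand(2, 3, 256, 256) * 2 - 1  # made-up inputs in [-1, 1]
y = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips(x, y)  # per-sample perceptual distance, shape [2, 1, 1, 1]
print(d.flatten())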
+ +"""The combined loss functions for continuous-space tokenizers training.""" + +import numpy as np +import torch +import torch.nn as nn +from skimage.metrics import structural_similarity as ssim + +from cosmos_predict1.tokenizer.modules.utils import time2batch +from cosmos_predict1.utils.lazy_config import instantiate + +_VALID_METRIC_NAMES = ["PSNR", "SSIM", "CodeUsage"] +_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) +_FLOAT32_EPS = torch.finfo(torch.float32).eps +_RECONSTRUCTION = "reconstructions" +_QUANT_INFO = "quant_info" + + +class TokenizerMetric(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.metric_modules = nn.ModuleDict() + for key in _VALID_METRIC_NAMES: + self.metric_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NULLMetric() + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + metric = dict() + for _, module in self.metric_modules.items(): + metric.update(module(inputs, output_batch, iteration)) + return dict(metric=metric) + + +class NULLMetric(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + return dict() + + +class PSNRMetric(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + reconstructions = output_batch[_RECONSTRUCTION] + if inputs.ndim == 5: + inputs, _ = time2batch(inputs) + reconstructions, _ = time2batch(reconstructions) + + # Normalize to uint8 [0..255] range. + true_image = (inputs.to(torch.float32) + 1) / 2 + pred_image = (reconstructions.to(torch.float32) + 1) / 2 + true_image = (true_image * _UINT8_MAX_F + 0.5).to(torch.uint8) + pred_image = (pred_image * _UINT8_MAX_F + 0.5).to(torch.uint8) + + # Calculate PNSR, based on Mean Squared Error (MSE) + true_image = true_image.to(torch.float32) + pred_image = pred_image.to(torch.float32) + mse = torch.mean((true_image - pred_image) ** 2, dim=(1, 2, 3)) + psnr = 10 * torch.log10(_UINT8_MAX_F**2 / (mse + _FLOAT32_EPS)) + return dict(PSNR=torch.mean(psnr)) + + +class SSIMMetric(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + reconstructions = output_batch[_RECONSTRUCTION] + if inputs.ndim == 5: + inputs, _ = time2batch(inputs) + reconstructions, _ = time2batch(reconstructions) + + # Normalize to uint8 [0..255] range. 
+ true_image = (inputs.to(torch.float32) + 1) / 2 + pred_image = (reconstructions.to(torch.float32) + 1) / 2 + true_image = (true_image * _UINT8_MAX_F + 0.5).to(torch.uint8) + pred_image = (pred_image * _UINT8_MAX_F + 0.5).to(torch.uint8) + + # Move tensors to CPU and convert to numpy arrays + true_image_np = true_image.permute(0, 2, 3, 1).cpu().numpy() + pred_image_np = pred_image.permute(0, 2, 3, 1).cpu().numpy() + + # Calculate SSIM for each image in the batch and average over the batch + ssim_values = [] + for true_image_i, pred_image_i in zip(true_image_np, pred_image_np): + ssim_value = ssim(true_image_i, pred_image_i, data_range=_UINT8_MAX_F, multichannel=True, channel_axis=-1) + ssim_values.append(ssim_value) + ssim_mean = np.mean(ssim_values) + return dict(SSIM=torch.tensor(ssim_mean, dtype=torch.float32, device=inputs.device)) + + +class CodeUsageMetric(torch.nn.Module): + """ + Calculate the perplexity of codebook usage (only for discrete tokenizers) + + :param codebook_indices: Tensor of codebook indices (quant_info) + :param codebook_size: The total number of codebook entries + :return: Perplexity of the codebook usage + """ + + def __init__(self, codebook_size: int) -> None: + super().__init__() + self.codebook_size = codebook_size + + def forward( + self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int + ) -> dict[str, torch.Tensor]: + code_indices = output_batch[_QUANT_INFO] + usage_counts = torch.bincount(code_indices.flatten().int(), minlength=self.codebook_size) + total_usage = usage_counts.sum().float() + usage_probs = usage_counts.float() / total_usage + entropy = -torch.sum(usage_probs * torch.log(usage_probs + _FLOAT32_EPS)) + perplexity = torch.exp(entropy) + return dict(CodeUsage=torch.tensor(perplexity, dtype=torch.float32, device=code_indices.device)) diff --git a/cosmos_predict1/tokenizer/training/model.py b/cosmos_predict1/tokenizer/training/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1a75630d01950abcdf0f7247fd8868a3a612694c --- /dev/null +++ b/cosmos_predict1/tokenizer/training/model.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
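Editorial aside, a standalone sketch of the codebook-usage perplexity that CodeUsageMetric above computes for discrete tokenizers (not part of this diff); the codebook size and indices are made-up example values.

import torch

codebook_size = 8
# Made-up discrete token indices, e.g. quant_info for a batch of 4 latent grids.
indices = torch.randint(0, codebook_size, (4, 16, 16))

counts = torch.bincount(indices.flatten(), minlength=codebook_size).float()
probs = counts / counts.sum()
entropy = -torch.sum(probs * torch.log(probs + torch.finfo(torch.float32).eps))
perplexity = torch.exp(entropy)  # ~codebook_size for uniform usage, ~1 for a collapsed codebook
print(perplexity)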
+ +"""Implements the forward op for training, validation, and inference.""" + +from typing import Any + +import torch + +from cosmos_predict1.tokenizer.training.datasets.utils import IMAGE_KEY, INPUT_KEY, MASK_KEY, RECON_KEY, VIDEO_KEY +from cosmos_predict1.tokenizer.training.losses.continuous import RECON_CONSISTENCY_KEY, VIDEO_CONSISTENCY_LOSS +from cosmos_predict1.utils import ema +from cosmos_predict1.utils.lazy_config import LazyDict, instantiate +from cosmos_predict1.utils.model import Model + +PREDICTION = "prediction" +EMA_PREDICTION = "ema_prediction" + + +class TokenizerModel(Model): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.network = instantiate(config.network) + self.loss = instantiate(config.loss) + self.metric = instantiate(config.metric) + self.precision = getattr(torch, config.precision) + if self.config.ema.enabled: + self.ema = ema.EMAModelTracker( + self, + beta=self.config.ema.beta, + torch_compile_buffer_renaming=self.config.ema.torch_compile_buffer_renaming, + ) + self.init_input_keys() + + def init_input_keys(self): + self.image_key = IMAGE_KEY + self.video_key = VIDEO_KEY + + def get_input_key(self, data_batch: dict[str, torch.Tensor]) -> str: + if self.image_key in data_batch: + return self.image_key + elif self.video_key in data_batch: + return self.video_key + else: + raise ValueError("Input key not found in data_batch.") + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the network. + + Args: + optimizer_config: The optimizer config for the net. + scheduler_config: The scheduler config for the net. + + Returns: + optimizer (torch.optim.Optimizer): The net optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The net optimization scheduler. + """ + optimizer_config.params = self.network.parameters() + optimizer = instantiate(optimizer_config) + scheduler_config.optimizer = optimizer + scheduler = instantiate(scheduler_config) + + return optimizer, scheduler + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + if self.config.ema.enabled: + self.ema.to(dtype=torch.float32) + self.network = self.network.to(dtype=self.precision, memory_format=memory_format) + self.loss = self.loss.to(dtype=self.precision, memory_format=memory_format) + + def state_dict( + self, destination: dict[str, Any] = None, prefix: str = "", keep_vars: bool = False + ) -> dict[str, Any]: + original_state_dict = super(TokenizerModel, self).state_dict(destination, prefix, keep_vars) + + # Filter out '.loss' and 'ema.loss-' keys from the state dict. + filtered_state_dict = {k: v for k, v in original_state_dict.items() if not k.startswith("loss.")} + filtered_state_dict = {k: v for k, v in filtered_state_dict.items() if not k.startswith("ema.loss-")} + filtered_state_dict = { + k: v for k, v in filtered_state_dict.items() if not k.startswith("network.encoder.patcher") + } + filtered_state_dict = { + k: v for k, v in filtered_state_dict.items() if not k.startswith("network.decoder.unpatcher") + } + + return filtered_state_dict + + def load_state_dict(self, state_dict: Any, strict: bool = True) -> None: + own_state = self.state_dict() + filtered_state_dict = {k: v for k, v in state_dict.items() if k in own_state} + + # Load only filtered state dict. 
+ super(TokenizerModel, self).load_state_dict(filtered_state_dict, strict=False) + + # If strict is True, ensure all parameters are loaded (except the excluded ones) + missing_keys = set(own_state.keys()) - set(filtered_state_dict.keys()) + if missing_keys and strict: + raise KeyError(f"Missing keys in state_dict: {missing_keys}") + + def _on_before_network_forward(self, data_batch: dict[str, torch.Tensor]) -> None: + consistency_loss = self.loss.loss_modules[VIDEO_CONSISTENCY_LOSS] + if hasattr(consistency_loss, "enabled") and consistency_loss.enabled: + _input_key = self.get_input_key(data_batch) + if _input_key is self.video_key: + data_batch[_input_key] = consistency_loss.shuffle(data_batch[_input_key]) + return + + def _on_after_network_forward( + self, data_batch: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor] + ) -> None: + consistency_loss = self.loss.loss_modules[VIDEO_CONSISTENCY_LOSS] + if hasattr(consistency_loss, "enabled") and consistency_loss.enabled: + _input_key = self.get_input_key(data_batch) + if _input_key is self.video_key: + data_batch[_input_key] = consistency_loss.unshuffle(data_batch[_input_key]) + output_batch[RECON_CONSISTENCY_KEY] = torch.ones_like(output_batch[RECON_KEY]) * output_batch[RECON_KEY] + output_batch[RECON_KEY] = consistency_loss.unshuffle(output_batch[RECON_KEY]) + return + + def _network_forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + # A callback proxy to modify the input before the forward pass. + self._on_before_network_forward(data_batch) + + # Do the forward pass. + tensor_batch = data_batch[self.get_input_key(data_batch)] + output_batch = self.network(tensor_batch) + output_batch = output_batch if self.network.training else output_batch._asdict() + + # A callback proxy to modify the output after the forward pass. 
+ self._on_after_network_forward(data_batch, output_batch) + return output_batch + + def training_step( + self, + data_batch: dict[str, torch.Tensor], + iteration: int, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + _input_key = self.get_input_key(data_batch) + output_dict = self._network_forward(data_batch) + input_images, recon_images = data_batch[_input_key], output_dict[RECON_KEY] + + # pass loss_mask to loss computation + inputs = {INPUT_KEY: input_images, MASK_KEY: data_batch.get("loss_mask", torch.ones_like(input_images))} + + loss_dict, loss_value = self.loss(inputs, output_dict, iteration) + return dict({PREDICTION: recon_images, **loss_dict}), loss_value + + @torch.no_grad() + def validation_step( + self, + data_batch: dict[str, torch.Tensor], + iteration: int, + ema_model: bool = False, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + _input_key = self.get_input_key(data_batch) + output_dict = self._network_forward(data_batch) + input_images, recon_images = data_batch[_input_key], output_dict[RECON_KEY] + + # pass loss_mask to loss computation + inputs = {INPUT_KEY: input_images, MASK_KEY: data_batch.get("loss_mask", torch.ones_like(input_images))} + + loss_dict, loss_value = self.loss(inputs, output_dict, iteration) + metric_dict = self.metric(input_images, output_dict, iteration) + loss_dict.update(metric_dict) + prediction_key = EMA_PREDICTION if ema_model else PREDICTION + return dict({prediction_key: recon_images, **loss_dict}), loss_value + + @torch.inference_mode() + def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + _input_key = self.get_input_key(data_batch) + output_dict = self._network_forward(data_batch) + return dict({PREDICTION: output_dict[RECON_KEY]}) diff --git a/cosmos_predict1/tokenizer/training/train.py b/cosmos_predict1/tokenizer/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4c2cf58c1f68327c2739326ead7e557c13fe25 --- /dev/null +++ b/cosmos_predict1/tokenizer/training/train.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import os + +from loguru import logger as logging + +from cosmos_predict1.utils.config import Config, pretty_print_overrides +from cosmos_predict1.utils.config_helper import get_config_module, override +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyConfig + + +@logging.catch(reraise=True) +def launch(config: Config, args: argparse.Namespace) -> None: + # Check that the config is valid + config.validate() + # Freeze the config so developers don't change it during training. 
+ config.freeze() # type: ignore + trainer = config.trainer.type(config) + # Create the model + model = instantiate(config.model) + model.on_model_init_end() + dataloader_train = instantiate(config.dataloader_train) + dataloader_val = instantiate(config.dataloader_val) + # Start training + trainer.train( + model, + dataloader_train, + dataloader_val, + ) + + +if __name__ == "__main__": + # Usage: torchrun --nproc_per_node=1 -m scripts.train --config=projects/tutorials/mnist/config.py + + # Get the config file from the input arguments. + parser = argparse.ArgumentParser(description="Training") + parser.add_argument("--config", help="Path to the config file", required=True) + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--dryrun", + action="store_true", + help="Do a dry run without training. Useful for debugging the config.", + ) + args = parser.parse_args() + config_module = get_config_module(args.config) + config = importlib.import_module(config_module).make_config() + config = override(config, args.opts) + if args.dryrun: + logging.info( + "Config:\n" + config.pretty_print(use_color=True) + "\n" + pretty_print_overrides(args.opts, use_color=True) + ) + os.makedirs(config.job.path_local, exist_ok=True) + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + print(f"{config.job.path_local}/config.yaml") + else: + # Launch the training job. + launch(config, args) diff --git a/cosmos_predict1/tokenizer/training/trainer.py b/cosmos_predict1/tokenizer/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e63cb884cbc11d95f40f8a7338ec0e72439bc54e --- /dev/null +++ b/cosmos_predict1/tokenizer/training/trainer.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.data + +from cosmos_predict1.tokenizer.training.checkpointer import TokenizerCheckpointer +from cosmos_predict1.utils import ema, misc +from cosmos_predict1.utils.model import Model +from cosmos_predict1.utils.trainer import Trainer + + +class TokenizerTrainer(Trainer): + """The tokenizers traine, extended from Trainer. + + It extends model training functionality. + + Attributes: + checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states. + training_timer (misc.Timer): Timer object to time code blocks and functions. 
+ """ + + def __init__(self, config): + super(TokenizerTrainer, self).__init__(config) + self.model_config = config.model.config + self.checkpointer = TokenizerCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks) + + @torch.no_grad() + def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: + """Validate on the full validation dataset. + + Args: + model (Model): The PyTorch model. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + iteration (int): Current iteration number. + """ + self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) + model.eval() + # Evaluate on the full validation set. + for val_iter, data_batch in enumerate(dataloader_val): + if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: + break + data_batch = misc.to(data_batch, device="cuda") + self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) + output_batch, _ = model.validation_step(data_batch, iteration) + with ema.ema_scope(model, enabled=model.config.ema.enabled): + ema_output_batch, loss = model.validation_step(data_batch, iteration, ema_model=True) + output_batch.update(ema_output_batch) + self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) + self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/cosmos_predict1/utils/__init__.py b/cosmos_predict1/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dac9a4d7496eb38831f1f3c820a90d50e25e2a7e --- /dev/null +++ b/cosmos_predict1/utils/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict1/utils/base_world_generation_pipeline.py b/cosmos_predict1/utils/base_world_generation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..673b150c2ff73907d3fbfa3a6d8d537b760b61dc --- /dev/null +++ b/cosmos_predict1/utils/base_world_generation_pipeline.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import os +from abc import ABC +from typing import Any + +import numpy as np +import torch + +from cosmos_predict1.auxiliary.guardrail.common import presets as guardrail_presets +from cosmos_predict1.auxiliary.t5_text_encoder import CosmosT5TextEncoder + + +class BaseWorldGenerationPipeline(ABC): + def __init__( + self, + inference_type: str | None = None, + checkpoint_dir: str | None = None, + checkpoint_name: str | None = None, + has_text_input: bool = False, + offload_network: bool = False, + offload_tokenizer: bool = False, + offload_text_encoder_model: bool = False, + offload_guardrail_models: bool = False, + disable_guardrail: bool = False, + ): + """Initialize base world generation pipeline. + + This abstract base class provides core functionality for world generation models including: + - Model loading and initialization + - Text encoding and embedding + - Safety checks and content filtering + - Memory management through model offloading + + Args: + inference_type: The type of inference pipeline ("text2world" or "video2world") + checkpoint_dir: Root directory containing model checkpoints + checkpoint_name: Name of the specific checkpoint file to load + has_text_input: Whether the pipeline takes text input for world generation + offload_network: If True, moves main model to CPU after inference + offload_tokenizer: If True, moves tokenizer to CPU after use + offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding + offload_guardrail_models: If True, moves safety models to CPU after checks + disable_guardrail: If True, disable guardrail + """ + self.inference_type = inference_type + self.checkpoint_dir = checkpoint_dir + self.checkpoint_name = checkpoint_name + self.has_text_input = has_text_input + + # Add offloading flags + self.offload_network = offload_network + self.offload_tokenizer = offload_tokenizer + self.offload_text_encoder_model = offload_text_encoder_model + self.offload_guardrail_models = offload_guardrail_models + + self.disable_guardrail = disable_guardrail + + # Initialize model instances + self.text_guardrail = None + self.video_guardrail = None + self.text_encoder = None + self.model = None + + self._load_model() + + if not self.offload_text_encoder_model: + self._load_text_encoder_model() + if not self.disable_guardrail and not self.offload_guardrail_models: + if self.has_text_input: + self._load_text_guardrail() + self._load_video_guardrail() + if not self.offload_network: + self._load_network() + if not self.offload_tokenizer: + self._load_tokenizer() + + def _load_tokenizer(self): + pass + + def _load_network(self): + pass + + def _load_model(self, checkpoint_name: str) -> Any: + """Load the world generation model from a checkpoint. + + This abstract method must be implemented by subclasses to load their specific + model architecture and weights. + + Args: + checkpoint_name: Path to the model checkpoint file + + Returns: + The loaded model instance + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + pass + + def _load_text_encoder_model(self): + """Load the T5 text encoder model. + + Initializes and loads the T5 encoder model used for converting text prompts + into embeddings that condition the world generation model. + + Returns: + Loaded T5 text encoder model instance + """ + self.text_encoder = CosmosT5TextEncoder(cache_dir=os.path.join(self.checkpoint_dir, "google-t5/t5-11b")) + + def _load_text_guardrail(self): + """Load text safety classifier models. 
+ + Initializes models used for checking input prompts against safety policies. + Models are loaded from the specified guardrail directory. + """ + self.text_guardrail = guardrail_presets.create_text_guardrail_runner(checkpoint_dir=self.checkpoint_dir) + + def _load_video_guardrail(self): + """Load video safety classifier models. + + Initializes models used for validating generated video content against + safety policies. Models are loaded from the specified guardrail directory. + """ + self.video_guardrail = guardrail_presets.create_video_guardrail_runner(checkpoint_dir=self.checkpoint_dir) + + def _offload_network(self): + if self.model.model: + del self.model.model + self.model.model = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_tokenizer(self): + if self.model.tokenizer: + del self.model.tokenizer + self.model.tokenizer = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_guardrail_models(self): + """Offload safety classifier models to reduce memory usage. + + Moves safety models to CPU and clears GPU memory if they are no longer needed. + This helps manage memory when processing multiple inputs sequentially. + """ + if self.text_guardrail: + del self.text_guardrail + self.text_guardrail = None + if self.video_guardrail: + del self.video_guardrail + self.video_guardrail = None + gc.collect() + torch.cuda.empty_cache() + + def _offload_text_encoder_model(self): + """Offload T5 text encoder to reduce memory usage. + + Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete. + This helps manage memory when processing multiple inputs sequentially. + """ + if self.text_encoder: + del self.text_encoder + self.text_encoder = None + gc.collect() + torch.cuda.empty_cache() + + def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Generate world latents using the model. + + This abstract method must be implemented by subclasses to define their specific + generation process. + + Args: + *args: Variable positional arguments for model inference + **kwargs: Variable keyword arguments for model inference + + Returns: + torch.Tensor: Generated world representation tensor + """ + pass + + def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Generate world representation with memory management. + + Handles loading the model before inference and offloading afterward if enabled. + This helps minimize GPU memory usage during inference. + + Args: + *args: Arguments passed to _run_model + **kwargs: Keyword arguments passed to _run_model + + Returns: + np.ndarray: Generated world representation as numpy array + """ + pass + + def _run_guardrail_on_prompt(self, prompt: str) -> bool: + """Check if prompt meets safety requirements. + + Validates the input prompt against safety policies using loaded guardrail models. + + Args: + prompt: Raw text prompt to validate + + Returns: + bool: True if prompt passes all safety checks, False otherwise + """ + return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail) + + def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool: + """Check prompt safety with memory management. + + Validates prompt safety while handling model loading/offloading to manage memory. 
+ + Args: + prompt: Raw text prompt to validate + + Returns: + bool: True if prompt passes all safety checks, False otherwise + """ + if self.offload_guardrail_models: + self._load_text_guardrail() + + is_safe = self._run_guardrail_on_prompt(prompt) + + if self.offload_guardrail_models: + self._offload_guardrail_models() + + return is_safe + + def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None: + """Check if video meets safety requirements. + + Validates generated video content against safety policies using guardrail models. + + Args: + video: Video frames to validate + + Returns: + np.ndarray: Processed video if safe, None if unsafe + """ + return guardrail_presets.run_video_guardrail(video, self.video_guardrail) + + def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None: + """Check if generated video meets safety requirements. + + Args: + video: Video frames to validate + + Returns: + np.ndarray: Processed video frames if safe, None otherwise + + Note: + Guardrail models are offloaded after checks if enabled. + """ + if self.offload_guardrail_models: + self._load_video_guardrail() + + video = self._run_guardrail_on_video(video) + + if self.offload_guardrail_models: + self._offload_guardrail_models() + return video + + def _run_text_embedding_on_prompt( + self, prompts: list[str], **kwargs: Any + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Convert text prompts to embeddings. + + Processes text prompts into embedding tensors that condition the generation model. + + Args: + prompts: List of text prompts to encode + **kwargs: Additional arguments for text encoding + + Returns: + tuple containing: + - List of text embedding tensors for each prompt + - List of attention masks for each embedding + """ + + embeddings = [] + masks = [] + for prompt in prompts: + embedding, mask = self.text_encoder.encode_prompts( + [prompt], + **kwargs, + ) + embeddings.append(embedding) + masks.append(mask) + + return embeddings, masks + + def _run_text_embedding_on_prompt_with_offload( + self, prompts: list[str], **kwargs: Any + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Convert text prompt into embeddings using T5 encoder. + + Args: + prompt: Processed and validated text prompt + + Returns: + Text embedding tensor to condition diffusion model + + Note: + T5 model is offloaded after encoding if enabled. + """ + if self.offload_text_encoder_model: + self._load_text_encoder_model() + + embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs) + + if self.offload_text_encoder_model: + self._offload_text_encoder_model() + return embeddings, masks + + def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray: + """Decode model outputs into final world representation. + + This abstract method must be implemented by subclasses to convert raw model + outputs into their specific world representation format. + + Args: + samples: Raw output tensor from the generation model + + Returns: + np.ndarray: Decoded world representation + """ + pass + + def generate(self, *args: Any, **kwargs: Any): + """Generate world representation. + + This abstract method must be implemented by subclasses to convert raw model + outputs into their specific world representation format. 
+ + Args: + *args: Variable positional arguments for model inference + **kwargs: Variable keyword arguments for model inference + """ + pass diff --git a/cosmos_predict1/utils/callback.py b/cosmos_predict1/utils/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f4b32beed9e4fb9a279a7d5c28fa278bfa4478 --- /dev/null +++ b/cosmos_predict1/utils/callback.py @@ -0,0 +1,403 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +import warnings +from typing import TYPE_CHECKING, Any, Callable, Optional + +import omegaconf +import torch +import torch.utils.data +import tqdm + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.lazy_config import instantiate +from cosmos_predict1.utils.misc import get_local_tensor_if_DTensor + +if TYPE_CHECKING: + from cosmos_predict1.utils.config import Config + from cosmos_predict1.utils.model import Model + from cosmos_predict1.utils.trainer import Trainer + + +class CallBackGroup: + """A class for hosting a collection of callback objects. + + It is used to execute callback functions of multiple callback objects with the same method name. + When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs + self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match. + + Attributes: + _callbacks (list[Callback]): List of callback objects. + """ + + def __init__(self, config: Config, trainer: Trainer) -> None: + """Initializes the list of callback objects. + + Args: + config (Config): The config object for the codebase. + trainer (Trainer): The main trainer. + """ + self._callbacks = [] + callback_configs = config.trainer.callbacks + if callback_configs: + if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig): + warnings.warn( + "The 'config.trainer.callbacks' parameter should be a dict instead of a list. " + "Please update your code", + DeprecationWarning, + stacklevel=2, + ) + callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)} + for callback_name, current_callback_cfg in callback_configs.items(): + if "_target_" not in current_callback_cfg: + log.critical( + f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}" + ) + continue + log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}") + _callback = instantiate(current_callback_cfg) + assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback." + _callback.config = config + _callback.trainer = trainer + self._callbacks.append(_callback) + + def __getattr__(self, method_name: str) -> Callable: + """Loops through the callback objects to call the corresponding callback function. 
+ + Args: + method_name (str): Callback method name. + """ + + def multi_callback_wrapper(*args, **kwargs) -> None: + for callback in self._callbacks: + assert hasattr(callback, method_name) + method = getattr(callback, method_name) + assert callable(method) + _ = method(*args, **kwargs) + + return multi_callback_wrapper + + +class Callback: + """The base class for all callbacks. + + All callbacks should inherit from this class and adhere to the established method names and signatures. + """ + + def __init__(self, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): + """Initializes a Callback object. + + Args: + config (Optional[Config]): The configuration object for the codebase, if available. + trainer (Optional[Trainer]): The main trainer handling the training loop, if available. + + Notes: + The config and trainer parameters are optional to maintain backward compatibility. + In future releases, these parameters will be removed. Upon using these parameters, a deprecation + warning will be issued. + + """ + if config is not None or trainer is not None: + warnings.warn( + "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. " + "Please update your code to create Callback instances without these parameters.", + DeprecationWarning, + stacklevel=2, + ) + del config, trainer + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + pass + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + pass + + def on_before_forward(self, iteration: int = 0) -> None: + pass + + def on_after_forward(self, iteration: int = 0) -> None: + pass + + def on_before_backward( + self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0 + ) -> None: + pass + + def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None: + pass + + def on_before_dataloading(self, iteration: int = 0) -> None: + pass + + def on_after_dataloading(self, iteration: int = 0) -> None: + pass + + def on_optimizer_init_start(self) -> None: + pass + + def on_optimizer_init_end(self) -> None: + pass + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + pass + + def on_before_zero_grad( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + iteration: int = 0, + ) -> None: + pass + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + pass + + def on_validation_start( + self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 + ) -> None: + pass + + def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + pass + + def on_validation_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + pass + + def on_validation_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_load_checkpoint_start(self, model: Model) -> None: + pass + + def on_load_checkpoint_end(self, model: Model) -> None: + pass + + def 
on_load_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: + pass + + def on_save_checkpoint_start(self, model: Model, iteration: int = 0) -> None: + pass + + def on_save_checkpoint_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_save_checkpoint_success(self, iteration: int = 0) -> None: + pass + + def on_save_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: + pass + + def on_train_end(self, model: Model, iteration: int = 0) -> None: + pass + + def on_app_end(self) -> None: + pass + + +class EMAModelCallback(Callback): + """The callback class for tracking EMA model weights.""" + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + # Set up the EMA model weight tracker. + if model.config.ema.enabled: + assert hasattr(model, "ema"), "EMA should be initialized from Model" + # EMA model must be kept in FP32 precision. + model.ema = model.ema.to(dtype=torch.float32) + else: + assert not hasattr(model, "ema"), "There should be no EMA initialized." + + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + # Update the EMA model with the new regular weights. + if model.config.ema.enabled: + model.ema.update_average(model, iteration) + + +class ProgressBarCallback(Callback): + """The callback class for visualizing the training/validation progress bar in the console.""" + + @distributed.rank0_only + def on_train_start(self, model: Model, iteration: int = 0) -> None: + self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") + + @distributed.rank0_only + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.train_pbar.update() + + @distributed.rank0_only + def on_validation_start( + self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 + ) -> None: + if self.config.trainer.max_val_iter is not None: + num_iter = self.config.trainer.max_val_iter + else: + num_iter = len(dataloader_val) + assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}" + self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False) + + @distributed.rank0_only + def on_validation_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.val_pbar.update() + + @distributed.rank0_only + def on_validation_end(self, model: Model, iteration: int = 0) -> None: + self.val_pbar.close() + + @distributed.rank0_only + def on_train_end(self, model: Model, iteration: int = 0) -> None: + self.trainer.checkpointer.finalize() + self.train_pbar.close() + + +class IterationLoggerCallback(Callback): + """The callback class for visualizing the training/validation progress bar in the console.""" + + @distributed.rank0_only + def on_train_start(self, model: Model, iteration: int = 0) -> None: + # self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") + self.start_iteration_time = time.time() + self.elapsed_iteration_time = 0 + + @distributed.rank0_only + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + self.start_iteration_time = time.time() + + 
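To illustrate the hook API defined by `Callback` above: a user-defined callback only overrides the hooks it needs, and `CallBackGroup.__getattr__` fans each call out to every registered callback with the same method name. A hypothetical example (the class name and behavior are invented for illustration):

```python
import torch

from cosmos_predict1.utils import log
from cosmos_predict1.utils.callback import Callback


class LossWatchdogCallback(Callback):
    """Hypothetical callback that warns when the training loss becomes non-finite."""

    def on_training_step_end(self, model, data_batch, output_batch, loss: torch.Tensor, iteration: int = 0) -> None:
        if not torch.isfinite(loss):
            log.warning(f"Non-finite loss at iteration {iteration}: {loss.item()}")
```

Such a callback would be registered through `config.trainer.callbacks` as an entry with a `_target_` field, mirroring the EMA and progress-bar defaults shown in `TrainerConfig` later in this diff.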
@distributed.rank0_only + def on_training_step_end( + self, + model: Model, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.elapsed_iteration_time += time.time() - self.start_iteration_time + + if iteration % self.config.trainer.logging_iter == 0: + avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter + log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}") + + self.elapsed_iteration_time = 0 + + +class GradClipCallback(Callback): + """The callback class for gradient clipping.""" + + def __init__( + self, + config: Optional["Config"] = None, + trainer: Optional["Trainer"] = None, + grad_clip_norm: float = 1.0, + ): + super().__init__(config, trainer) + self.grad_clip_norm = grad_clip_norm + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + grad_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm) + + +class LowPrecisionCallback(Callback): + """The callback class handling low precision training""" + + def __init__(self, update_iter: int, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): + super().__init__(config, trainer) + self.update_iter = update_iter + + def on_train_start(self, model: Model, iteration: int = 0) -> None: + assert model.precision in [ + torch.bfloat16, + torch.float16, + torch.half, + ], "LowPrecisionCallback must use a low precision dtype." + self.precision_type = model.precision + + def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + for k, v in data.items(): + if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): + data[k] = v.to(dtype=self.precision_type) + + def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: + for k, v in data.items(): + if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): + data[k] = v.to(dtype=self.precision_type) + + def on_before_zero_grad( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + iteration: int = 0, + ) -> None: + if iteration % self.update_iter == 0: + if getattr(optimizer, "master_weights", False): + params, master_params = [], [] + for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master): + for p, p_master in zip(group["params"], group_master["params"]): + params.append(get_local_tensor_if_DTensor(p.data)) + master_params.append(p_master.data) + torch._foreach_copy_(params, master_params) diff --git a/cosmos_predict1/utils/callbacks/grad_clip.py b/cosmos_predict1/utils/callbacks/grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4f320b6f79e1e289117d8190b5f6df52cf64ae --- /dev/null +++ b/cosmos_predict1/utils/callbacks/grad_clip.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from cosmos_predict1.utils import distributed +from cosmos_predict1.utils.callback import Callback + + +@torch.jit.script +def _fused_nan_to_num(params: List[torch.Tensor]): + for param in params: + torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param) + + +class GradClip(Callback): + def __init__( + self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False + ): + self.clip_norm = clip_norm + self.force_finite = force_finite + self.model_key = model_key + self.fsdp_enabled = fsdp_enabled + + def on_before_optimizer_step( + self, + model_ddp: distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int = 0, + ) -> None: + del optimizer, scheduler + if isinstance(model_ddp, distributed.DistributedDataParallel): + model = model_ddp.module + else: + model = model_ddp + + # select sub-network if specified + if self.model_key is not None: + items = self.model_key.split(".") + for item in items: + model = getattr(model, item) + + if self.force_finite: + params = [] + for param in model.parameters(): + if param.grad is not None: + params.append(param.grad) + # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + _fused_nan_to_num(params) + + # check if FSDP is used + # total_norm + if isinstance(model, FSDP) and self.fsdp_enabled: + model.clip_grad_norm_(self.clip_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) diff --git a/cosmos_predict1/utils/checkpointer.py b/cosmos_predict1/utils/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..91142e0bc9ded0c3d6e6d834c824db6e5551738e --- /dev/null +++ b/cosmos_predict1/utils/checkpointer.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import threading +from typing import TYPE_CHECKING + +import torch + +from cosmos_predict1.utils import callback, distributed, log, misc +from cosmos_predict1.utils.model import Model + +if TYPE_CHECKING: + from cosmos_predict1.utils.config import CheckpointConfig, JobConfig + + +class Checkpointer: + """The checkpointer class. 
Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + # Set the callback functions. + self.callbacks = callbacks + self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path or None + self.load_training_state = config_checkpoint.load_training_state + self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state + self.save_thread = None + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = f"iter_{iteration:09}.pt" + + if distributed.get_rank() == 0: + state_dict = dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + scheduler=scheduler.state_dict(), + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + state_dict = misc.to(state_dict, device="cpu") + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). 
+ """ + checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) + os.makedirs(self.checkpoint_dir_local, exist_ok=True) + try: + torch.save(state_dict, checkpoint_path) + if rank == 0: + self._write_latest_checkpoint_file(checkpoint_file) + log.success(f"Saved checkpoint (local): {checkpoint_path}") + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to save (local): {e}") + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + only_resume_scheduler = True + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + resume = self.load_training_state + only_resume_scheduler = self.only_load_scheduler_state + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + only_resume_scheduler = False + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + # Load the state dicts. 
+ log.info("- Loading the model...") + if "model" in state_dict: + model.load_state_dict(state_dict["model"], strict=self.strict_resume) + else: + model.load_state_dict(state_dict, strict=self.strict_resume) + if resume or only_resume_scheduler: + iteration = state_dict["iteration"] + assert scheduler + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + scheduler.last_epoch = iteration + else: + iteration = 0 + if resume: + assert optimizer + log.info("- Loading the optimizer...") + optimizer.load_state_dict(state_dict["optimizer"]) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(state_dict["grad_scaler"]) + log.success(f"Done with loading the checkpoint (iteration {iteration}).") + else: + log.success("Done with loading the checkpoint.") + else: + # Checkpoint not found and not specified. We will train everything from scratch. + iteration = 0 + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. + """ + checkpoint_file = None + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + if os.path.isfile(latest_path): + checkpoint_file = open(latest_path).read().strip() + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + with open(latest_path, "w") as file: + file.write(content) + + def _check_checkpoint_exists(self, checkpoint_path: str) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. + """ + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"File not found (local): {checkpoint_path}") + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() diff --git a/cosmos_predict1/utils/config.py b/cosmos_predict1/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fe08db8872dd9c3c9f821ae68ff1d1544fe6cc90 --- /dev/null +++ b/cosmos_predict1/utils/config.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +from typing import Any, Dict, Optional, Type, TypeVar, Union + +import attrs +import torch +from megatron.core import ModelParallelConfig + +from cosmos_predict1.utils import callback +from cosmos_predict1.utils.lazy_config import LazyCall as L +from cosmos_predict1.utils.lazy_config import LazyDict +from cosmos_predict1.utils.misc import Color + +T = TypeVar("T") + + +def _is_attrs_instance(obj: object) -> bool: + """ + Helper function to check if an object is an instance of an attrs-defined class. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is an instance of an attrs-defined class, False otherwise. + """ + return hasattr(obj, "__attrs_attrs__") + + +def make_freezable(cls: T) -> T: + """ + A decorator that adds the capability to freeze instances of an attrs-defined class. + + NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need + to hack on a "_is_frozen" attribute. + + This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime. + Once an instance is frozen, its attributes cannot be changed. It also recursively freezes + any attrs-defined objects that are attributes of the class. + + Usage: + @make_freezable + @attrs.define(slots=False) + class MyClass: + attribute1: int + attribute2: str + + obj = MyClass(1, 'a') + obj.freeze() # Freeze the instance + obj.attribute1 = 2 # Raises AttributeError + + Args: + cls: The class to be decorated. + + Returns: + The decorated class with added freezing capability. + """ + + if not hasattr(cls, "__dict__"): + raise TypeError( + "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped " + "class was defined with `@attrs.define(slots=False)`" + ) + + original_setattr = cls.__setattr__ + + def setattr_override(self, key, value) -> None: # noqa: ANN001 + """ + Override __setattr__ to allow modifications during initialization + and prevent modifications once the instance is frozen. + """ + if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen": + raise AttributeError("Cannot modify frozen instance") + original_setattr(self, key, value) # type: ignore + + cls.__setattr__ = setattr_override # type: ignore + + def freeze(self: object) -> None: + """ + Freeze the instance and all its attrs-defined attributes. + """ + for _, value in attrs.asdict(self, recurse=False).items(): + if _is_attrs_instance(value) and hasattr(value, "freeze"): + value.freeze() + self._is_frozen = True # type: ignore + + cls.freeze = freeze # type: ignore + + return cls + + +def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str: + """ + Recursively pretty prints attrs objects with color. 
+ """ + + assert attrs.has(obj.__class__) + + lines: list[str] = [] + for attribute in attrs.fields(obj.__class__): + value = getattr(obj, attribute.name) + if attrs.has(value.__class__): + if use_color: + lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":") + else: + lines.append(" " * indent + "* " + attribute.name + ":") + lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color)) + else: + if use_color: + lines.append( + " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value) + ) + else: + lines.append(" " * indent + "* " + attribute.name + ": " + str(value)) + return "\n".join(lines) + + +def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str: + """ + Pretty prints overrides. + """ + + lines: list[str] = [] + lines.append(Color.cyan("* ") + Color.green("overrides") + ": ") + for override in overrides: + if override == "--": + continue + attribute_name, attribute_value = override.split("=") + if use_color: + lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value)) + else: + lines.append(" " + "* " + attribute_name + ": " + str(attribute_value)) + + return "\n".join(lines) + + +@make_freezable +@attrs.define(slots=False) +class JobConfig: + # Project name. + project: str = "" + # Experiment name. + group: str = "" + # Run/job name. + name: str = "" + + @property + def path(self) -> str: + return f"{self.project}/{self.group}/{self.name}" + + @property + def path_local(self) -> str: + local_root = os.environ.get("OUTPUT_ROOT", "checkpoints") + return f"{local_root}/{self.path}" + + +@make_freezable +@attrs.define(slots=False) +class EMAConfig: + # Enable tracking a set of exponential moving average (EMA) weights. + enabled: bool = False + # EMA decay rate. + beta: float = 0.9999 + # Enable removing "_orig_mod-" from buffer names that is added by torch.compile + torch_compile_buffer_renaming: bool = False + + +@make_freezable +@attrs.define(slots=False) +class DDPConfig: + # Traverse the computation graph to find parameters that don't receive gradients. + find_unused_parameters: bool = False + # Set to True if the computation graph does not change during the whole training loop. + static_graph: bool = True + # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere. + broadcast_buffers: bool = True + + +@make_freezable +@attrs.define(slots=False) +class CuDNNConfig: + # Set to True for better reproducibility of the results (only using deterministic cudnn functions). + deterministic: bool = False + # If set to True, cudnn will benchmark several algorithms and pick the fastest one. + benchmark: bool = True + + +@make_freezable +@attrs.define(slots=False) +class JITConfig: + # Enable exporting a JIT compiled model. + enabled: bool = False + # Input tensor shape, for example input. + input_shape: Union[list[int], None] = None + # Device to compile onto. + device: str = "cuda" + # # Data type to compile onto. + dtype: str = "bfloat16" + # Strict mode for PyTorch JIT. + strict: bool = True + + +@make_freezable +@attrs.define(slots=False) +class CheckpointConfig: + # possible checkpoint class + type: Optional[Dict] = None + # for dcp, whether to use async mode + dcp_async_mode_enabled: bool = False + # Save the checkpoint every N iterations. + save_iter: int = 999999999 + # Path of model weights to resume the checkpoint from. 
+ load_path: str = "" + # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path. + load_training_state: bool = False + # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored. + only_load_scheduler_state: bool = False + # Load state_dict to the models in strict mode. + strict_resume: bool = True + # Print detailed information during checkpoint saving/loading. + verbose: bool = True + # Configs for JIT compiling EMA model. + jit: JITConfig = attrs.field(factory=JITConfig) + # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"] + keys_not_to_resume: list[str] = [] + # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer). + broadcast_via_filesystem: bool = False + load_ema_to_reg: bool = False + async_saving: bool = True + + +@make_freezable +@attrs.define(slots=False) +class TrainerConfig: + from cosmos_predict1.utils.trainer import Trainer + + type: Type[Trainer] = Trainer + # Set the callback class. + # Defaults to the callbacks below. + callbacks: LazyDict = LazyDict( + dict( + ema=L(callback.EMAModelCallback)(), + progress_bar=L(callback.ProgressBarCallback)(), + ) + ) + # distributed parallelism strategy + distributed_parallelism: str = "ddp" + # Distributed data parallel configs. + ddp: DDPConfig = attrs.field(factory=DDPConfig) + # cuDNN configs. + cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig) + # Set the random seed. + seed: int = 0 + # Gradient scaler arguments (for torch.amp.GradScaler). + grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False)) + # Maximum number of iterations to train the model. + max_iter: int = 999999999 + # Maximum number of iterations to validate the model. If None, validate on the entire dataset. + max_val_iter: int | None = None + # How often we log the training stats. + logging_iter: int = 100 + # Whether we want to run the validation routines. + run_validation: bool = True + # How often we evaluate on the validation set. + validation_iter: int = 999999999 + # Kill the process after N seconds since the last iteration (usually means dead job). + timeout_period: int = 999999999 + # Tensor memory organization format. + memory_format: torch.memory_format = torch.preserve_format + # Gradient accumulation (update step every N iteration). + grad_accum_iter: int = 1 + # # Profiling config + # profiling: Profiling = attrs.field(factory=Profiling) + + +@make_freezable +@attrs.define(slots=False) +class Config: + """Config for a job. + + See /README.md/Configuration System for more info. + """ + + # Model configs. + model: LazyDict + # Optimizer configs. + optimizer: LazyDict = LazyDict(dict(dummy=None)) + # Scheduler configs. + scheduler: LazyDict = LazyDict(dict(dummy=None)) + # Training data configs. + dataloader_train: LazyDict = LazyDict(dict(dummy=None)) + # Validation data configs. + dataloader_val: LazyDict = LazyDict(dict(dummy=None)) + + # Training job configs. + job: JobConfig = attrs.field(factory=JobConfig) + + # Trainer configs. + trainer: TrainerConfig = attrs.field(factory=TrainerConfig) + + # Megatron-Core configs + model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig) + + # Checkpointer configs. 
+ checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig) + + def pretty_print(self, use_color: bool = False) -> str: + return _pretty_print_attrs_instance(self, 0, use_color) + + # Training job configs. + job: JobConfig = attrs.field(factory=JobConfig) + + def to_dict(self) -> dict[str, Any]: + return attrs.asdict(self) + + def validate(self) -> None: + """Validate that the config has all required fields.""" + assert self.job.project != "", "Project name is required." + assert self.job.group != "", "Group name is required." + assert self.job.name != "", "Job name is required." diff --git a/cosmos_predict1/utils/config_helper.py b/cosmos_predict1/utils/config_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd1d21934aa68fdb6bcae35777c38a8dff644d2 --- /dev/null +++ b/cosmos_predict1/utils/config_helper.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +import pkgutil +import sys +from dataclasses import fields as dataclass_fields +from dataclasses import is_dataclass +from typing import Any, Dict, Optional + +import attr +import attrs +from hydra import compose, initialize +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig, OmegaConf + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.config import Config + + +def is_attrs_or_dataclass(obj) -> bool: + """ + Check if the object is an instance of an attrs class or a dataclass. + + Args: + obj: The object to check. + + Returns: + bool: True if the object is an instance of an attrs class or a dataclass, False otherwise. + """ + return is_dataclass(obj) or attr.has(type(obj)) + + +def get_fields(obj): + """ + Get the fields of an attrs class or a dataclass. + + Args: + obj: The object to get fields from. Must be an instance of an attrs class or a dataclass. + + Returns: + list: A list of field names. + + Raises: + ValueError: If the object is neither an attrs class nor a dataclass. + """ + if is_dataclass(obj): + return [field.name for field in dataclass_fields(obj)] + elif attr.has(type(obj)): + return [field.name for field in attr.fields(type(obj))] + else: + raise ValueError("The object is neither an attrs class nor a dataclass.") + + +def override(config: Config, overrides: Optional[list[str]] = None) -> Config: + """ + :param config: the instance of class `Config` (usually from `make_config`) + :param overrides: list of overrides for config + :return: the composed instance of class `Config` + """ + # Store the class of the config for reconstruction after overriding. + # config_class = type(config) + + # Convert Config object to a DictConfig object + config_dict = attrs.asdict(config) + config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) + # Enforce "--" separator between the script arguments and overriding configs. 
+ if overrides: + if overrides[0] != "--": + raise ValueError('Hydra config overrides must be separated with a "--" token.') + overrides = overrides[1:] + # Use Hydra to handle overrides + cs = ConfigStore.instance() + cs.store(name="config", node=config_omegaconf) + with initialize(version_base=None): + config_omegaconf = compose(config_name="config", overrides=overrides) + OmegaConf.resolve(config_omegaconf) + + def config_from_dict(ref_instance: Any, kwargs: Any) -> Any: + """ + Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data + + Args: + ref_instance: The reference instance to determine the type and fields when needed + kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data + + Returns: + Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data + + Raises: + AssertionError: If the fields do not match or if extra keys are found. + Exception: If there is an error constructing the new instance. + """ + is_type = is_attrs_or_dataclass(ref_instance) + if not is_type: + return kwargs + else: + ref_fields = set(get_fields(ref_instance)) + assert isinstance(kwargs, dict) or isinstance( + kwargs, DictConfig + ), "kwargs must be a dictionary or a DictConfig" + keys = set(kwargs.keys()) + + # ref_fields must equal to or include all keys + extra_keys = keys - ref_fields + assert ref_fields == keys or keys.issubset( + ref_fields + ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}" + + resolved_kwargs: Dict[str, Any] = {} + for f in keys: + resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f]) + try: + new_instance = type(ref_instance)(**resolved_kwargs) + except Exception as e: + log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}") + log.error(e) + raise e + return new_instance + + config = config_from_dict(config, config_omegaconf) + + return config + + +def get_config_module(config_file: str) -> str: + if not config_file.endswith(".py"): + log.error("Config file cannot be specified as module.") + log.error("Please provide the path to the Python config file (relative to the Cosmos root).") + + cosmos_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + assert os.path.isfile(config_file) or os.path.isfile(os.path.join(cosmos_root, config_file)), \ + f"Cosmos config file ({config_file}) not found." + + # Convert to importable module format. + config_module = config_file.replace("/", ".").replace(".py", "") + return config_module + + +def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None: + """ + Import all modules from the specified package path recursively. + + This function is typically used in conjunction with Hydra to ensure that all modules + within a specified package are imported, which is necessary for registering configurations. + + Example usage: + ```python + import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True, skip_underscore=False) + ``` + + Args: + package_path (str): The dotted path to the package from which to import all modules. + reload (bool): Flag to determine whether to reload modules if they're already imported. + skip_underscore (bool): If True, skips importing modules that start with an underscore. 
+ """ + log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}") + package = importlib.import_module(package_path) + package_directory = package.__path__ + + def import_modules_recursively(directory: str, prefix: str) -> None: + """ + Recursively imports or reloads all modules in the given directory. + + Args: + directory (str): The file system path to the current package directory. + prefix (str): The module prefix (e.g., 'models.diffusion.config'). + """ + for _, module_name, is_pkg in pkgutil.iter_modules([directory]): + if skip_underscore and module_name.startswith("_"): + log.debug(f"Skipping module {module_name} as it starts with an underscore") + continue + + full_module_name = f"{prefix}.{module_name}" + log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}") + + if full_module_name in sys.modules and reload: + importlib.reload(sys.modules[full_module_name]) + else: + importlib.import_module(full_module_name) + + if is_pkg: + sub_package_directory = os.path.join(directory, module_name) + import_modules_recursively(sub_package_directory, full_module_name) + + for directory in package_directory: + import_modules_recursively(directory, package_path) diff --git a/cosmos_predict1/utils/ddp_checkpointer.py b/cosmos_predict1/utils/ddp_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..593a5c932c7f274973bbc390cf2386c1fa59df2f --- /dev/null +++ b/cosmos_predict1/utils/ddp_checkpointer.py @@ -0,0 +1,436 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +from collections import namedtuple +from typing import Any, Dict, Optional, Set, Tuple, Union + +import torch +import torch.distributed +from megatron.core import parallel_state +from torch.distributed import ProcessGroup, get_process_group_ranks + +from cosmos_predict1.utils import distributed, log, misc +from cosmos_predict1.utils.base import AbstractCheckpointer +from cosmos_predict1.utils.checkpointer.safe_broadcast import broadcast_object +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.model import Model + +StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) + + +class Checkpointer(AbstractCheckpointer): + """ + Checkpointer for DDP. + Note: This implementation only supports local filesystem. + """ + + KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] + KEYS_TO_POSTFIX = { + "model": "model", + "optim": "optim", + "scheduler": "scheduler", + "trainer": "", + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + ep_world_size = parallel_state.get_expert_model_parallel_world_size() + assert pp_world_size < 2, "Pipeline Parallelism (PP) is not tested yet." 
+ assert ep_world_size < 2, "Expert Parallelism (EP) is not tested yet." + self.mp_world_size = parallel_state.get_model_parallel_group().size() + if self.mp_world_size > 1 and self.__class__ == Checkpointer: + raise NotImplementedError( + "Model Parallelism (MP) is enabled - " + "you should use TensorParallel Checkpointer instead of DDP Checkpointer." + ) + # DDP rank (with context parallelism considered) + self.rank_dp_w_cp = parallel_state.get_data_parallel_rank(with_context_parallel=True) + # Context parallelism rank + self.cp_rank = parallel_state.get_context_parallel_rank() + # Model parallelism rank (including Tensor+Pipeline+Expert Parallelisms) + self.mp_rank = parallel_state.get_model_parallel_group().rank() + # self.mp_rank = parallel_state.get_model_parallel_group(with_expert_parallel=ep_world_size > 1).rank() + if self.broadcast_via_filesystem: + log.info("Broadcasting checkpoint data via the local filesystem.") + if not self.strict_resume: + log.warning("Strict resume mode is off. Some model parameters may not be loaded.") + + # collect ranks of all model parallel groups + all_ranks = [None for _ in range(distributed.get_world_size())] + torch.distributed.all_gather_object( + all_ranks, get_process_group_ranks(parallel_state.get_model_parallel_group()) + ) + all_ranks = list(set(tuple(rank) if isinstance(rank, list) else rank for rank in all_ranks)) + for ranks in all_ranks: + group = torch.distributed.new_group(list(ranks), backend="gloo") + if distributed.get_rank() in ranks: + self.mp_gloo_pg = group + + self.print("Checkpointer Initialized.") + + def print(self, message: str): + """ + Print message to the console. Include the parallelism rank information when verbose is set to True. + """ + if self.verbose: + log.info( + f"[Parallelism Rank: DP-{self.rank_dp_w_cp}, TP-{self.mp_rank}, CP-{self.cp_rank}]: {message}", + rank0_only=False, + ) + else: + log.info(message, rank0_only=True) + + def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: + del model + assert key in self.KEYS_TO_SAVE + post_fix = self.KEYS_TO_POSTFIX[key] + + if post_fix: + _ckpt_path = checkpoint_path.replace(".pt", f"_{post_fix}.pt") + else: + _ckpt_path = checkpoint_path + return _ckpt_path + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + checkpoint_file = self.format_checkpoint_filename(model, iteration) + state_dict = self.generate_save_state_dict(model, optimizer, scheduler, grad_scaler, iteration) + state_dict = self._map_state_dict_path_during_save(state_dict, checkpoint_file, model) + if state_dict: + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. 
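+ # daemon=False: the interpreter waits for non-daemon threads at shutdown, so an in-flight
+ # checkpoint write is allowed to finish rather than being cut off when training exits.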
+ self.save_thread = threading.Thread( + target=self._save_worker, + daemon=False, + args=(state_dict, checkpoint_file, distributed.get_rank()), + ) + self.save_thread.start() + + # Note: Checkpoints are saved on a separate thread and this callback is not accurate. + # Please check logs from on_save_checkpoint_success() for better accuracy + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + def _map_state_dict_path_during_save(self, state_dict, checkpoint_file, model) -> dict[str, StateDictItemPath]: + new_dict = {} + for key, _state_dict in state_dict.items(): + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_file, model) + checkpoint_path = os.path.join(self.save_dirname, _ckpt_path) + new_dict[key] = StateDictItemPath(_state_dict, checkpoint_path) + return new_dict + + @misc.timer("checkpoint saving") + def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: + """Worker to save checkpoint to disk, spawned with a child thread (in parallel with the training). + + Args: + state_dict (dict[str, StateDictItemPath]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). + """ + try: + for key, item in state_dict.items(): + self.print(f"Saving {key} to {item.save_path}") + try: + easy_io.dump( + item.state_dict, + item.save_path, + fast_backend=True, # optional for fast backend, cpu heavy + ) + self.print(f"Saved {key} to {item.save_path}") + except Exception as e: + self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") + raise # Re-raise the exception after logging + + # Synchronize only rank 0 of each model parallel group + if self.mp_world_size > 1: + torch.distributed.barrier(group=self.mp_gloo_pg) + + # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt + if self.mp_rank == 0 and self.rank_dp_w_cp == 0: + self._write_latest_checkpoint_file(checkpoint_file) + + if distributed.get_rank() == 0: # only rank 0 saves trained_data_record + if "trained_data_record" in state_dict["model"].state_dict: + self._write_trained_data_record( + checkpoint_file, state_dict["model"].state_dict["trained_data_record"] + ) + + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) + + def format_checkpoint_filename(self, model: Model, iteration: int) -> str: + """Generate the checkpoint file name. + + Args: + iteration (int): The current iteration number. + + Returns: + checkpoint_file (str): The checkpoint file name. 
+ """ + del self, model + return f"iter_{iteration:09}.pt" + + @misc.timer("generate saving state dict") + def generate_save_state_dict( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + ) -> Optional[Dict[str, Any]]: + state_dict = {} + + if self.rank_dp_w_cp == 0: + trainer_state = dict( + grad_scaler=grad_scaler.state_dict(), + iteration=iteration, + ) + model_state = model.state_dict() + optim_state = optimizer.state_dict() + scheduler_state = scheduler.state_dict() + self.callbacks.on_save_checkpoint(model, state_dict=trainer_state) + + trainer_state, model_state, optim_state, scheduler_state = misc.to( + [trainer_state, model_state, optim_state, scheduler_state], device="cpu" + ) + + state_dict = { + "model": model_state, + "optim": optim_state, + "scheduler": scheduler_state, + } + if distributed.get_rank() == 0: # only rank 0 saves trainer state + state_dict["trainer"] = trainer_state + return state_dict + return state_dict + + def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: + """ + Load state_dict and broadcast. + + The main steps are: + 1. Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + 2. Each rank loads its corresponding checkpoint from the local cache or receives it via broadcast. + + This approach ensures that each MP rank loads its specific part of the model, which is + crucial for Model Parallelism where different parts of the model are distributed across + multiple GPUs. + + When using Model Parallelism (e.g., Tensor Parallelism), the `broadcast_via_filesystem` option can + be set to True. This allows each rank to load its specific checkpoint from the local filesystem + instead of receiving it via network broadcast, which could be more efficient in some cases. + + For standard DDP without TP, `broadcast_via_filesystem` should remain False (default). + + Args: + checkpoint_path (str): The base path of the checkpoint. + model (Model): The model being loaded. + resume_keys (Set): Set of keys to resume from the checkpoint. + + Returns: + dict[str, Any]: A dictionary containing the loaded state for each resumed key. + """ + state_dict = {} + sorted_resume_keys = sorted(resume_keys) + # Step 1: Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. + if self.rank_dp_w_cp == 0: + for key in sorted_resume_keys: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + if os.path.exists(local_cache_path): + # If the local checkpoint exists, we can directly load it + self.print(f"Checkpoint is already in local cache: {local_cache_path}. 
Loading...") + _state_dict = easy_io.load(local_cache_path, fast_backend=True) + else: + _state_dict = easy_io.load(_ckpt_path, fast_backend=True) + self.print(f"Downloading checkpoint from: {_ckpt_path}") + if self.broadcast_via_filesystem: + # Save the checkpoint to the local filesystem + easy_io.dump(_state_dict, local_cache_path, fast_backend=True) + state_dict[key] = _state_dict + # Ensure all ranks wait for the download to complete + distributed.barrier() + + # Step 2: Broadcast checkpoint data + log.info( + "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", + rank0_only=True, + ) + for key in sorted_resume_keys: + if self.broadcast_via_filesystem: + # Load the checkpoint from the local filesystem for other ranks + if self.rank_dp_w_cp != 0: + _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) + local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) + self.print(f"Loading checkpoint from: {local_cache_path}") + state_dict[key] = easy_io.load(local_cache_path, fast_backend=True) + else: + # Broadcast the checkpoint to all GPUs of the current DDP rank + group: ProcessGroup = parallel_state.get_data_parallel_group(with_context_parallel=True) + min_rank = min(get_process_group_ranks(group)) + + _state_dict = broadcast_object( + state_dict[key] if self.rank_dp_w_cp == 0 else None, + min_rank, + group=group, + device=torch.device(torch.cuda.current_device()), + ) + if self.rank_dp_w_cp == 0: + self.print(f'Broadcasted checkpoint["{key}"] to all other ranks in the same DDP group.') + else: + state_dict[key] = _state_dict + self.print(f'Received checkpoint["{key}"] from source rank {min_rank}.') + + return state_dict + + def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: + latest_checkpoint_file = self._read_latest_checkpoint_file() + + resume_keys = [] + + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) + resume_keys.extend(self.KEYS_TO_SAVE) + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + if self.load_training_state: + resume_keys.extend(self.KEYS_TO_SAVE) + else: + resume_keys.append("model") + if self.only_load_scheduler_state: + resume_keys.append("scheduler") + else: + checkpoint_path = None + if len(self.keys_not_to_resume) > 0: + for key in self.keys_not_to_resume: + assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" + resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] + return set(resume_keys), checkpoint_path + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. 
If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + resume_keys, checkpoint_path = self.keys_to_resume_during_load() + + iteration = 0 + + # Load checkpoint. + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + state_dict = self.load_broadcast_state_dict(checkpoint_path, model, set(resume_keys)) + + if "trainer" in state_dict: + trainer_state = state_dict["trainer"] + log.critical(state_dict.keys(), rank0_only=False) + log.critical(trainer_state, rank0_only=False) + log.info("- Loading the gradient scaler...") + grad_scaler.load_state_dict(trainer_state["grad_scaler"]) + self.callbacks.on_load_checkpoint(model, state_dict=trainer_state) + iteration = trainer_state["iteration"] + if "optim" in state_dict: + assert optimizer + optimizer_state = state_dict["optim"] + log.info("- Loading the optimizer...") + optimizer.load_state_dict(optimizer_state) + if "scheduler" in state_dict: + assert scheduler + scheduler_state = state_dict["scheduler"] + log.info("- Loading the scheduler...") + scheduler.load_state_dict(scheduler_state) + scheduler.last_epoch = iteration + if "model" in state_dict: + model_state = state_dict["model"] + log.info("- Loading the model...") + # model.load_state_dict(model_state) + if self.strict_resume: + log.info("\t Strict resume mode is on.") + else: + log.info("\t Strict resume mode is off.") + model_load_info = model.load_state_dict(model_state, strict=self.strict_resume) + log.info(f"\t {model_load_info}") + self.print(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + return iteration + + def _write_trained_data_record(self, checkpoint_file: str, trained_data_record: dict[str, int]) -> None: + """Write json file to save number of seen samples and number of iterations. + + Args: + checkpoint_file (str): iteration number for the saved checkpoint + trained_data_record (dict[str, int]): example {"image": 0, "video": 0, "iteration": 0}. + """ + # filename: iter_xxxxxxxxx_trained_data_record.json + checkpoint_path = os.path.join( + self.save_dirname, f"{checkpoint_file.replace('.pt', '')}_trained_data_record.json" + ) + easy_io.dump(trained_data_record, checkpoint_path) diff --git a/cosmos_predict1/utils/device.py b/cosmos_predict1/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..db486afabd4ae0bf11feb05d8a4efd96690ce64b --- /dev/null +++ b/cosmos_predict1/utils/device.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os + +import pynvml + + +class Device: + """A class to handle NVIDIA GPU device operations using NVML. + + This class provides an interface to access and manage NVIDIA GPU devices, + including retrieving device information and CPU affinity settings. + + Attributes: + _nvml_affinity_elements (int): Number of 64-bit elements needed to represent CPU affinity + """ + + _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore + + def __init__(self, device_idx: int): + """Initialize a Device instance for a specific GPU. + + Args: + device_idx (int): Index of the GPU device to manage + + Raises: + NVMLError: If the device cannot be found or initialized + """ + super().__init__() + self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) + + def get_cpu_affinity(self) -> list[int]: + """Get the CPU affinity mask for this GPU device. + + Retrieves the CPU affinity mask indicating which CPU cores are assigned + to this GPU device. The affinity is returned as a list of CPU core indices. + + Returns: + list[int]: List of CPU core indices that have affinity with this GPU + + Raises: + NVMLError: If the CPU affinity information cannot be retrieved + + Example: + >>> device = Device(0) + >>> device.get_cpu_affinity() + [0, 1, 2, 3] # Shows this GPU has affinity with CPU cores 0-3 + """ + affinity_string = "" + for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): + # assume nvml returns list of 64 bit ints + affinity_string = "{:064b}".format(j) + affinity_string + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is in 0th element of list + return [i for i, e in enumerate(affinity_list) if e != 0] diff --git a/cosmos_predict1/utils/distributed.py b/cosmos_predict1/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..b827c3bacab1e093dbb586bf1dcffe86ae5fa825 --- /dev/null +++ b/cosmos_predict1/utils/distributed.py @@ -0,0 +1,445 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
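As a rough usage sketch for the helpers defined in this module (`init`, `get_rank`, `is_rank0`, `rank0_only`, `barrier`, `all_gather_tensor`): the snippet below assumes a `torchrun`-style launch that sets `LOCAL_RANK`, `RANK`, and `WORLD_SIZE`; the driver script itself and the tensor shapes are invented for illustration and are not part of this patch.

```python
# Hypothetical driver script exercising cosmos_predict1.utils.distributed.
import torch

from cosmos_predict1.utils import distributed


@distributed.rank0_only
def log_banner() -> None:
    # Executed only on the global master rank; other ranks receive None.
    print(f"world_size={distributed.get_world_size()}")


def main() -> None:
    distributed.init()  # sets the CUDA device and initializes the NCCL process group
    log_banner()
    x = torch.ones(4, device="cuda") * distributed.get_rank()
    gathered = distributed.all_gather_tensor(x)  # one tensor per rank, in rank order
    distributed.barrier()  # wait for all ranks before reducing on rank 0
    if distributed.is_rank0():
        print(torch.stack(gathered).sum())


if __name__ == "__main__":
    main()  # e.g. torchrun --nproc_per_node=8 this_script.py
```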
+ +from __future__ import annotations + +import collections +import collections.abc +import ctypes +import functools +import os +from contextlib import contextmanager +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Callable, Container, Optional + +import pynvml +import torch +import torch.distributed as dist +from torch.distributed import get_process_group_ranks + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.device import Device + +if TYPE_CHECKING: + from cosmos_predict1.utils.config import DDPConfig + +if dist.is_available(): + from torch.distributed.distributed_c10d import _get_default_group + from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes + + +try: + from megatron.core import parallel_state +except ImportError: + print("Megatron-core is not installed.") + + +def init() -> int | None: + """Initialize distributed training.""" + # Set GPU affinity. + pynvml.nvmlInit() + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = Device(local_rank) + # os.sched_setaffinity(0, device.get_cpu_affinity()) + # Set up NCCL communication. + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0" + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" + if dist.is_available(): + if dist.is_initialized(): + return torch.cuda.current_device() + torch.cuda.set_device(local_rank) + # Get the timeout value from environment variable + timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800) + # Convert the timeout to an integer (if it isn't already) and then to a timedelta + timeout_timedelta = timedelta(seconds=int(timeout_seconds)) + dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta) + log.critical( + f"Initialized distributed program with local rank {local_rank} with timeout {timeout_seconds}", + rank0_only=False, + ) + # Increase the L2 fetch granularity for faster speed. + _libcudart = ctypes.CDLL("libcudart.so") + # Set device limit on the current device. + p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) + _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) + _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05)) + log.info(f"Running with {get_world_size()} GPUs.") + + +def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: + """Get the rank (GPU device) of the worker. + + Returns: + rank (int): The rank of the worker. + """ + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group) + return rank + + +def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int: + """Get world size. How many GPUs are available in this job. + + Returns: + world_size (int): The total number of GPUs available in this job. + """ + world_size = 1 + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size(group) + return world_size + + +def is_rank0() -> bool: + """Check if current process is the master GPU. + + Returns: + (bool): True if this function is called from the master GPU, else False. + """ + return get_rank() == 0 + + +def is_local_rank0() -> bool: + """Check if current process is the local master GPU in the current node. + + Returns: + (bool): True if this function is called from the local master GPU, else False. 
+ """ + return torch.cuda.current_device() == 0 + + +def device_with_rank(device: str) -> str: + """If the device is 'cuda' and parallelism over GPUs is enabled, returns + Otherwise, returns the device as-is.""" + if device == 'cuda': + return f'cuda:{get_rank()}' + return device + + +def rank0_only(func: Callable) -> Callable: + """Apply this function only to the master GPU. + + Example usage: + @rank0_only + def func(x): + return x + 3 + + Args: + func (Callable): a function. + + Returns: + (Callable): A function wrapper executing the function only on the master GPU. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + if is_rank0(): + return func(*args, **kwargs) + else: + return None + + return wrapper + + +def barrier() -> None: + """Barrier for all GPUs.""" + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + +def rank0_first(func: Callable) -> Callable: + """run the function on rank 0 first, then on other ranks.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + if is_rank0(): + result = func(*args, **kwargs) + barrier() + if not is_rank0(): + result = func(*args, **kwargs) + return result + + return wrapper + + +def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel: + """Wraps the model to enable data parallalism for training across multiple GPU devices. + + Args: + config_ddp (DDPConfig): The data parallel config. + model (torch.nn.Module): The PyTorch module. + + Returns: + model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper + if distributed environment is available, otherwise return the original model. + """ + if dist.is_available() and dist.is_initialized(): + local_rank = int(os.getenv("LOCAL_RANK", 0)) + try: + ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + except Exception as e: + log.info(e) + log.info("parallel_state not initialized, treating all GPUs equally for DDP") + ddp_group = None + + model = DistributedDataParallel( + model, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=config_ddp.find_unused_parameters, + static_graph=config_ddp.static_graph, + broadcast_buffers=config_ddp.broadcast_buffers, + process_group=ddp_group, + ) + return model + + +class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): + """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). + + This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an coreModel such that + model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling + model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> + training_step), allowing us to preserve the function names and signatures. + """ + + def __init__(self, model: torch.nn.Module, *args, **kwargs): + super().__init__(model, *args, **kwargs) + self.show_sync_grad_static_graph_warning = True + + def training_step(self, *args, **kwargs) -> Any: + # Cache the original model.forward() method. + original_forward = self.module.forward + + def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 + # Unpatch immediately before calling training_step() because itself may want to call the real forward. + self.module.forward = original_forward + # The actual .training_step(). 
+ return self.module.training_step(*_args, **_kwargs) + + # Patch the original_module's forward so we can redirect the arguments back to the real method. + self.module.forward = wrapped_training_step + # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). + # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. + return self(*args, **kwargs) + + +@contextmanager +def ddp_sync_grad(model, enabled): + r""" + Context manager to enable/disable gradient synchronizations across DDP processes for DDP model. + Modified from: + https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync + Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True. + + Within this context, gradients will be accumulated on module + variables, which will later be synchronized in the first + forward-backward pass exiting the context. + + .. warning:: + The forward pass should be included inside the context manager, or + else gradients will still be synchronized. + """ + assert isinstance(model, torch.nn.Module) + if isinstance(model, DistributedDataParallel): + old_require_backward_grad_sync = model.require_backward_grad_sync + if model.static_graph and model.require_backward_grad_sync != enabled: + if model.show_sync_grad_static_graph_warning: + log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.") + model.show_sync_grad_static_graph_warning = False + else: + model.require_backward_grad_sync = enabled + try: + yield + finally: + if isinstance(model, DistributedDataParallel): + model.require_backward_grad_sync = old_require_backward_grad_sync + + +def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]: + """Aggregate the list of data batches from all devices and process the results. + + This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler. + It will return the data/output of the entire validation set in its original index order. The sizes of data_batches + in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be + created before calling dis.all_gather(). + + Args: + data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where + leaf entries are tensors. + + Returns: + data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where + leaf entries are concatenated tensors. + """ + if isinstance(data_batches[0], torch.Tensor): + # Concatenate the local data batches. + data_concat = torch.cat(data_batches, dim=0) # type: ignore + # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank. + max_num_local_samples = torch.tensor(len(data_concat), device="cuda") + dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX) + if len(data_concat) < max_num_local_samples: + assert len(data_concat) + 1 == max_num_local_samples + dummy = torch.empty_like(data_concat[:1]) + data_concat = torch.cat([data_concat, dummy], dim=0) + dummy_count = torch.tensor(1, device="cuda") + else: + dummy_count = torch.tensor(0, device="cuda") + # Get all concatenated batches from all ranks and concatenate again. 
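+ # Sum the dummy-sample count across ranks first, so the padded entries can be dropped from the tail after gathering.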
+ dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM) + data_concat = all_gather_tensor(data_concat.contiguous()) + data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1) + # Remove the dummy samples. + if dummy_count > 0: + data_collate = data_collate[:-dummy_count] + elif isinstance(data_batches[0], collections.abc.Mapping): + data_collate = dict() + for key in data_batches[0].keys(): + data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore + else: + raise TypeError + return data_collate + + +@torch.no_grad() +def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]: + """Gather the corresponding tensor from all GPU devices to a list. + + Args: + tensor (torch.Tensor): Pytorch tensor. + + Returns: + tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices. + """ + tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] + dist.all_gather(tensor_list, tensor) + return tensor_list + + +def broadcast(tensor, src, group=None, async_op=False): + world_size = get_world_size() + if world_size < 2: + return tensor + dist.broadcast(tensor, src=src, group=group, async_op=async_op) + + +def sync_model_states( + model: torch.nn.Module, + process_group: Optional[dist.ProcessGroup] = None, + src: int = 0, + params_and_buffers_to_ignore: Optional[Container[str]] = None, + broadcast_buffers: bool = True, +): + """ + Modify based on DDP source code + Synchronizes the parameters and buffers of a model across different processes in a distributed setting. + + This function ensures that all processes in the specified process group have the same initial parameters and + buffers from the source rank, typically rank 0. It is useful when different processes start with different model + states and a synchronization is required to ensure consistency across all ranks. + + Args: + model (nn.Module): The model whose parameters and buffers are to be synchronized. + process_group (dist.ProcessGroup, optional): The process group for communication. If None, + the default group is used. Defaults to None. + src (int, optional): The source rank from which parameters and buffers will be broadcasted. + Defaults to 0. + params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer + names to exclude from synchronization. Defaults to None, which means all parameters and buffers are + included. + broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True. + + Side Effects: + This function modifies the state of the model in-place to synchronize it with the source rank's model state. + + Raises: + RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised. + + Examples: + >>> # downloading duplicated model weights from s3 in each rank and save network bandwidth + >>> # useful and save our time when model weights are huge + >>> if dist.get_rank == 0: + >>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path)) + >>> dist.barrir() + >>> sync_model_states(model) # sync rank0 weights to other ranks + """ + if process_group is None: + process_group = _get_default_group() + if not params_and_buffers_to_ignore: + params_and_buffers_to_ignore = set() + + log.info( + f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}." + ) + + # Build tuple of (module, parameter) for all parameters that require grads. 
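+ # (Note: the requires_grad filter just below is commented out, so every named parameter not listed
+ # in params_and_buffers_to_ignore is synchronized, including frozen ones.)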
+ modules_and_parameters = [ + (module, parameter) + for module_name, module in model.named_modules() + for parameter in [ + param + # Note that we access module.named_parameters instead of + # parameters(module). parameters(module) is only needed in the + # single-process multi device case, where it accesses replicated + # parameters through _former_parameters. + for param_name, param in module.named_parameters(recurse=False) + if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore + # if param.requires_grad + # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore + ] + ] + + # Deduplicate any parameters that might be shared across child modules. + memo = set() + modules_and_parameters = [ + # "p not in memo" is the deduplication check. + # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. + (m, p) + for m, p in modules_and_parameters + if p not in memo and not memo.add(p) # type: ignore[func-returns-value] + ] + + # Build list of parameters. + parameters = [parameter for _, parameter in modules_and_parameters] + if len(parameters) == 0: + return + + _verify_param_shape_across_processes(process_group, parameters) + + _sync_module_states( + module=model, + process_group=process_group, + broadcast_bucket_size=int(250 * 1024 * 1024), + src=src, + params_and_buffers_to_ignore=params_and_buffers_to_ignore, + broadcast_buffers=broadcast_buffers, + ) + + +def dist_reduce_tensor(tensor, rank=0, reduce="mean"): + r"""Reduce to rank 0""" + world_size = get_world_size() + if world_size < 2: + return tensor + with torch.no_grad(): + dist.reduce(tensor, dst=rank) + if get_rank() == rank: + if reduce == "mean": + tensor /= world_size + elif reduce == "sum": + pass + else: + raise NotImplementedError + return tensor diff --git a/cosmos_predict1/utils/easy_io/__init__.py b/cosmos_predict1/utils/easy_io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3159bfe65645499015bd92609b99d476d69544e9 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
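The docstring of `sync_model_states` above sketches a load-on-rank-0-then-broadcast pattern (with a couple of typos such as `dist.barrir()`); a corrected, self-contained version of that sketch follows. The `weights_path` argument and the plain `torch.load` call are placeholders for whatever heavyweight fetch the caller actually performs.

```python
# Hedged sketch: rank 0 loads the weights once, then broadcasts them to every other rank.
import torch

from cosmos_predict1.utils.distributed import barrier, get_rank, sync_model_states


def load_weights_once(model: torch.nn.Module, weights_path: str) -> None:
    """Placeholder helper; only the global master rank touches the (possibly remote) weights."""
    if get_rank() == 0:
        state = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state)
    barrier()  # make sure rank 0 has finished loading before broadcasting
    sync_model_states(model)  # copy rank-0 parameters and buffers to all other ranks
```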
diff --git a/cosmos_predict1/utils/easy_io/backends/__init__.py b/cosmos_predict1/utils/easy_io/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c85c64735eb12bc6ed1e45b7681684efa0dbace --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/__init__.py @@ -0,0 +1,13 @@ +from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend +from cosmos_predict1.utils.easy_io.backends.http_backend import HTTPBackend +from cosmos_predict1.utils.easy_io.backends.local_backend import LocalBackend +from cosmos_predict1.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend + +__all__ = [ + "BaseStorageBackend", + "LocalBackend", + "HTTPBackend", + "register_backend", + "backends", + "prefix_to_backends", +] diff --git a/cosmos_predict1/utils/easy_io/backends/base_backend.py b/cosmos_predict1/utils/easy_io/backends/base_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..2db3b921f0b6fdb3aaea867c0bb3cafdb5e59888 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/base_backend.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +from abc import ABCMeta, abstractmethod + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == "": + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def has_method(obj, method): + return hasattr(obj, method) and callable(getattr(obj, method)) + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: :meth:`get()` and + :meth:`get_text()`. + + - :meth:`get()` reads the file as a byte stream. + - :meth:`get_text()` reads the file as texts. + """ + + # a flag to indicate whether the backend can create a symlink for a file + # This attribute will be deprecated in future. + _allow_symlink = False + + @property + def allow_symlink(self): + return self._allow_symlink + + @property + def name(self): + return self.__class__.__name__ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass diff --git a/cosmos_predict1/utils/easy_io/backends/http_backend.py b/cosmos_predict1/utils/easy_io/backends/http_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4c14c481c91e8c551ca898237bae39229ecd82 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/http_backend.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Union +from urllib.request import urlopen + +from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend + + +class HTTPBackend(BaseStorageBackend): + """HTTP and HTTPS storage bachend.""" + + def get(self, filepath: str) -> bytes: + """Read bytes from a given ``filepath``. + + Args: + filepath (str): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get('http://path/of/file') + b'hello world' + """ + return urlopen(filepath).read() + + def get_text(self, filepath, encoding="utf-8") -> str: + """Read text from a given ``filepath``. + + Args: + filepath (str): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = HTTPBackend() + >>> backend.get_text('http://path/of/file') + 'hello world' + """ + return urlopen(filepath).read().decode(encoding) + + @contextmanager + def get_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath`` to a local temporary directory, + and return the temporary path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Yields: + Iterable[str]: Only yield one temporary path. + + Examples: + >>> backend = HTTPBackend() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with backend.get_local_path('http://path/of/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.get(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) diff --git a/cosmos_predict1/utils/easy_io/backends/local_backend.py b/cosmos_predict1/utils/easy_io/backends/local_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..2d712bb53ddd844e20350c236dd5cfb999b60fc9 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/local_backend.py @@ -0,0 +1,550 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
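As a quick illustration of the `HTTPBackend` defined above (the URL below is a placeholder, not part of the patch):

```python
# Hedged sketch: fetching remote bytes/text and materializing a temporary local copy.
from cosmos_predict1.utils.easy_io.backends import HTTPBackend

backend = HTTPBackend()
url = "https://example.com/config.yaml"  # placeholder URL

raw = backend.get(url)        # whole response body as bytes
text = backend.get_text(url)  # same payload decoded as UTF-8

# get_local_path() downloads into a NamedTemporaryFile and removes it when the `with` block exits.
with backend.get_local_path(url) as local_path:
    print(local_path)  # e.g. /tmp/tmpXXXXXXXX
```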
+ +import io +import os +import os.path as osp +import shutil +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Iterator, Optional, Tuple, Union + +from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist + + +class LocalBackend(BaseStorageBackend): + """Raw local storage backend.""" + + _allow_symlink = True + + def get(self, filepath: Union[str, Path]) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get(filepath) + b'hello world' + """ + with open(filepath, "rb") as f: + value = f.read() + return value + + def get_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.get_text(filepath) + 'hello world' + """ + with open(filepath, encoding=encoding) as f: + text = f.read() + return text + + def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put(b'hello world', filepath) + """ + mkdir_or_exist(osp.dirname(filepath)) + if isinstance(obj, io.BytesIO): + obj.seek(0) + obj = obj.getvalue() + with open(filepath, "wb") as f: + f.write(obj) + + def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.put_text('hello world', filepath) + """ + mkdir_or_exist(osp.dirname(filepath)) + with open(filepath, "w", encoding=encoding) as f: + f.write(obj) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.exists(filepath) + True + """ + return osp.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. 
+ + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/dir' + >>> backend.isdir(filepath) + True + """ + return osp.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.isfile(filepath) + True + """ + return osp.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + r"""Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + + Examples: + >>> backend = LocalBackend() + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> backend.join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + return osp.join(filepath, *filepaths) + + @contextmanager + def get_local_path( + self, + filepath: Union[str, Path], + ) -> Generator[Union[str, Path], None, None]: + """Only for unified API and does nothing. + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> backend = LocalBackend() + >>> with backend.get_local_path('abc/def.jpg') as path: + ... # do something here + """ + yield filepath + + def copyfile( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> backend.copyfile(src, dst) + '/path1/of/dir/file' + """ + return shutil.copy(src, dst) + + def copytree( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will + be raised. 
+ + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree(src, dst) + '/path/of/dir2' + """ + return shutil.copytree(src, dst) + + def copyfile_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Copy a local file src to dst and return the destination file. Same + as :meth:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_from_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_from_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_from_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a + directory named dst and return the destination directory. Same as + :meth:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + + Returns: + str: The destination directory. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def copyfile_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + dst_type: Optional[str] = None, + ) -> str: + """Copy the file src to local dst and return the destination file. Same + as :meth:`copyfile`. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> backend = LocalBackend() + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> backend.copyfile_to_local(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to + >>> backend.copyfile_to_local(src, dst) + '/path1/of/dir/file' + """ + return self.copyfile(src, dst) + + def copytree_to_local( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> str: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + + Returns: + str: The destination directory. 
+ + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> backend.copytree_from_local(src, dst) + '/path/of/dir2' + """ + return self.copytree(src, dst) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str or Path): Path to be removed. + + Raises: + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + + Examples: + >>> backend = LocalBackend() + >>> filepath = '/path/of/file' + >>> backend.remove(filepath) + """ + if not self.exists(filepath): + raise FileNotFoundError(f"filepath {filepath} does not exist") + + if self.isdir(filepath): + raise IsADirectoryError("filepath should be a file") + + os.remove(filepath) + + def rmtree(self, dir_path: Union[str, Path]) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + + Examples: + >>> dir_path = '/path/of/dir' + >>> backend.rmtree(dir_path) + """ + shutil.rmtree(dir_path) + + def copy_if_symlink_fails( + self, + src: Union[str, Path], + dst: Union[str, Path], + ) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directly copy src + to dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + + Returns: + bool: Return True if successfully create a symbolic link pointing + to src. Otherwise, return False. + + Examples: + >>> backend = LocalBackend() + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> backend.copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> backend.copy_if_symlink_fails(src, dst) + True + """ + try: + os.symlink(src, dst) + return True + except Exception: + if self.isfile(src): + self.copyfile(src, dst) + else: + self.copytree(src, dst) + return False + + def list_dir_or_file( + self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + ) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> backend = LocalBackend() + >>> dir_path = '/path/of/dir' + >>> # list those files and directories in current directory + >>> for file_path in backend.list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): + ... 
print(file_path) + >>> # list all files and directory recursively + >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ # noqa: E501 + if list_dir and suffix is not None: + raise TypeError("`suffix` should be None when `list_dir` is True") + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError("`suffix` must be a string or tuple of strings") + + root = dir_path + + def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + if (suffix is None or rel_path.endswith(suffix)) and list_file: + yield rel_path + elif osp.isdir(entry.path): + if list_dir: + rel_dir = osp.relpath(entry.path, root) + yield rel_dir + if recursive: + yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive) + + return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos_predict1/utils/easy_io/backends/registry_utils.py b/cosmos_predict1/utils/easy_io/backends/registry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd70c3e548455f362b36cb9803693fa1ab5fbdbe --- /dev/null +++ b/cosmos_predict1/utils/easy_io/backends/registry_utils.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Optional, Type, Union + +from cosmos_predict1.utils.easy_io.backends.base_backend import BaseStorageBackend +from cosmos_predict1.utils.easy_io.backends.http_backend import HTTPBackend +from cosmos_predict1.utils.easy_io.backends.local_backend import LocalBackend + +backends: dict = {} +prefix_to_backends: dict = {} + + +def _register_backend( + name: str, + backend: Type[BaseStorageBackend], + force: bool = False, + prefixes: Union[str, list, tuple, None] = None, +): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (BaseStorageBackend): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. 
+ """ + global backends, prefix_to_backends + + if not isinstance(name, str): + raise TypeError("the backend name should be a string, " f"but got {type(name)}") + + if not inspect.isclass(backend): + raise TypeError(f"backend should be a class, but got {type(backend)}") + if not issubclass(backend, BaseStorageBackend): + raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") + + if name in backends and not force: + raise ValueError( + f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' + ) + backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + + for prefix in prefixes: + if prefix in prefix_to_backends and not force: + raise ValueError( + f"{prefix} is already registered as a storage backend," + ' add "force=True" if you want to override it' + ) + + prefix_to_backends[prefix] = backend + + +def register_backend( + name: str, + backend: Optional[Type[BaseStorageBackend]] = None, + force: bool = False, + prefixes: Union[str, list, tuple, None] = None, +): + """Register a backend. + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. + force (bool): Whether to override the backend if the name has already + been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefix + of the registered storage backend. Defaults to None. + + This method can be used as a normal method or a decorator. + + Examples: + + >>> class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + >>> register_backend('new', NewBackend) + + >>> @register_backend('new') + ... class NewBackend(BaseStorageBackend): + ... def get(self, filepath): + ... return filepath + ... + ... def get_text(self, filepath): + ... return filepath + """ + if backend is not None: + _register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + _register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + +register_backend("local", LocalBackend, prefixes="") +register_backend("http", HTTPBackend, prefixes=["http", "https"]) diff --git a/cosmos_predict1/utils/easy_io/easy_io.py b/cosmos_predict1/utils/easy_io/easy_io.py new file mode 100644 index 0000000000000000000000000000000000000000..de7189abf9def860d77bbbc778eb76658df41a2a --- /dev/null +++ b/cosmos_predict1/utils/easy_io/easy_io.py @@ -0,0 +1,1066 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import warnings +from contextlib import contextmanager +from io import BytesIO, StringIO +from pathlib import Path +from typing import IO, Any, Generator, Iterator, Optional, Tuple, Union + +from cosmos_predict1.utils.easy_io.backends import backends, prefix_to_backends +from cosmos_predict1.utils.easy_io.file_client import FileClient +from cosmos_predict1.utils.easy_io.handlers import file_handlers + +backend_instances: dict = {} + + +def is_filepath(filepath): + return isinstance(filepath, (str, Path)) + + +def _parse_uri_prefix(uri: Union[str, Path]) -> str: + """Parse the prefix of uri. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> _parse_uri_prefix('/home/path/of/your/file') + '' + >>> _parse_uri_prefix('s3://path/of/your/file') + 's3' + >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') + 's3' + + Returns: + str: Return the prefix of uri if the uri contains '://'. Otherwise, + return ''. + """ + assert is_filepath(uri) + uri = str(uri) + # if uri does not contains '://', the uri will be handled by + # LocalBackend by default + if "://" not in uri: + return "" + else: + prefix, _ = uri.split("://") + if ":" in prefix: + _, prefix = prefix.split(":") + return prefix + + +def _get_file_backend(prefix: str, backend_args: dict): + """Return a file backend based on the prefix or backend_args. + + Args: + prefix (str): Prefix of uri. + backend_args (dict): Arguments to instantiate the corresponding + backend. + """ + # backend name has a higher priority + if "backend" in backend_args: + # backend_args should not be modified + backend_args_bak = backend_args.copy() + backend_name = backend_args_bak.pop("backend") + backend = backends[backend_name](**backend_args_bak) + else: + backend = prefix_to_backends[prefix](**backend_args) + return backend + + +def get_file_backend( + uri: Union[str, Path, None] = None, + *, + backend_args: Optional[dict] = None, + enable_singleton: bool = False, + backend_key: Optional[str] = None, +): + """Return a file backend based on the prefix of uri or backend_args. + + Args: + uri (str or Path): Uri to be parsed that contains the file prefix. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + enable_singleton (bool): Whether to enable the singleton pattern. + If it is True, the backend created will be reused if the + signature is same with the previous one. Defaults to False. + backend_key: str: The key to register the backend. Defaults to None. + + Returns: + BaseStorageBackend: Instantiated Backend object. 
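        An illustrative case (a bare local path maps to the ``''`` prefix and hence the
        local backend; instance reuse relies on ``enable_singleton``):

            >>> b1 = get_file_backend('/path/of/file', enable_singleton=True)
            >>> b2 = get_file_backend('/path/of/file', enable_singleton=True)
            >>> b1 is b2
            True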
+ + Examples: + >>> # get file backend based on the prefix of uri + >>> uri = 's3://path/of/your/file' + >>> backend = get_file_backend(uri) + >>> # get file backend based on the backend_args + >>> backend = get_file_backend(backend_args={'backend': 's3'}) + >>> # backend name has a higher priority if 'backend' in backend_args + >>> backend = get_file_backend(uri, backend_args={'backend': 's3'}) + """ + global backend_instances + if backend_key is not None: + if backend_key in backend_instances: + return backend_instances[backend_key] + + if backend_args is None: + backend_args = {} + + if uri is None and "backend" not in backend_args and backend_key is None: + raise ValueError( + 'uri should not be None when "backend" does not exist in ' "backend_args and backend_key is None" + ) + + if uri is not None: + prefix = _parse_uri_prefix(uri) + else: + prefix = "" + + if enable_singleton: + unique_key = f"{prefix}:{json.dumps(backend_args)}" + if unique_key in backend_instances: + return backend_instances[unique_key] + + backend = _get_file_backend(prefix, backend_args) + backend_instances[unique_key] = backend + if backend_key is not None: + backend_instances[backend_key] = backend + return backend + else: + backend = _get_file_backend(prefix, backend_args) + return backend + + +def get( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bytes: + """Read bytes from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bytes: Expected bytes object. + + Examples: + >>> filepath = '/path/of/file' + >>> get(filepath) + b'hello world' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.get(filepath) + + +def get_text( + filepath: Union[str, Path], + encoding="utf-8", + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """Read text from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: Expected text reading from ``filepath``. + + Examples: + >>> filepath = '/path/of/file' + >>> get_text(filepath) + 'hello world' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.get_text(filepath, encoding) + + +def put( + obj: bytes, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Write bytes to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. 
+ + Examples: + >>> filepath = '/path/of/file' + >>> put(b'hello world', filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.put(obj, filepath) + + +def put_text( + obj: str, + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Write text to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + ``filepath``. Defaults to 'utf-8'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Examples: + >>> filepath = '/path/of/file' + >>> put_text('hello world', filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.put_text(obj, filepath) + + +def exists( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/file' + >>> exists(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.exists(filepath) + + +def isdir( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + + Examples: + >>> filepath = '/path/of/dir' + >>> isdir(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.isdir(filepath) + + +def isfile( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. 
+ + Examples: + >>> filepath = '/path/of/file' + >>> isfile(filepath) + True + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.isfile(filepath) + + +def join_path( + filepath: Union[str, Path], + *filepaths: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + r"""Concatenate all file paths. + + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + *filepaths (str or Path): Other paths to be concatenated. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: The result of concatenation. + + Examples: + >>> filepath1 = '/path/of/dir1' + >>> filepath2 = 'dir2' + >>> filepath3 = 'path/of/file' + >>> join_path(filepath1, filepath2, filepath3) + '/path/of/dir/dir2/path/of/file' + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + return backend.join_path(filepath, *filepaths) + + +@contextmanager +def get_local_path( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself and it will + not be released (removed). + + Args: + filepath (str or Path): Path to be read data. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: Only yield one path. + + Examples: + >>> with get_local_path('abc/def.jpg') as path: + ... # do something here + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + with backend.get_local_path(str(filepath)) as local_path: + yield local_path + + +def copyfile( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy a file src to dst and return the destination file. + + src and dst should have the same prefix. If dst specifies a directory, + the file will be copied into dst using the base filename from src. If + dst specifies a file that already exists, it will be replaced. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination file. + + Raises: + SameFileError: If src and dst are the same file, a SameFileError will + be raised. 
+ + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> # src will be copied to '/path1/of/file1' + >>> copyfile(src, dst) + '/path1/of/file1' + + >>> # dst is a directory + >>> dst = '/path1/of/dir' + >>> # src will be copied to '/path1/of/dir/file' + >>> copyfile(src, dst) + '/path1/of/dir/file' + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile(src, dst) + + +def copytree( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + src and dst should have the same prefix and dst must not already exist. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + backend_key (str, optional): The key to get the backend from register. + + Returns: + str: The destination directory. + + Raises: + FileExistsError: If dst had already existed, a FileExistsError will be + raised. + + Examples: + >>> src = '/path/of/dir1' + >>> dst = '/path/of/dir2' + >>> copytree(src, dst) + '/path/of/dir2' + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree(src, dst) + + +def copyfile_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy a local file src to dst and return the destination file. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A local file to be copied. + dst (str or Path): Copy file to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = '/path/of/file' + >>> dst = 's3://openmmlab/mmengine/file1' + >>> # src will be copied to 's3://openmmlab/mmengine/file1' + >>> copyfile_from_local(src, dst) + s3://openmmlab/mmengine/file1 + + >>> # dst is a directory + >>> dst = 's3://openmmlab/mmengine' + >>> # src will be copied to 's3://openmmlab/mmengine/file'' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/file' + """ + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile_from_local(src, dst) + + +def copytree_from_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a directory + named dst and return the destination directory. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A local directory to be copied. + dst (str or Path): Copy directory to dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. 
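    An illustrative call with this function itself (the s3 destination reuses the
    hypothetical bucket that appears elsewhere in these docstrings):

        >>> copytree_from_local('/path/of/dir', 's3://openmmlab/mmengine/dir')
        's3://openmmlab/mmengine/dir'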
+ + Examples: + >>> src = '/path/of/dir' + >>> dst = 's3://openmmlab/mmengine/dir' + >>> copyfile_from_local(src, dst) + 's3://openmmlab/mmengine/dir' + """ + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree_from_local(src, dst) + + +def copyfile_to_local( + src: Union[str, Path], + dst: Union[str, Path], + dst_type: str, # Choose from ["file", "dir"] + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Copy the file src to local dst and return the destination file. + + If dst specifies a directory, the file will be copied into dst using + the base filename from src. If dst specifies a file that already + exists, it will be replaced. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copyfile`. + + Args: + src (str or Path): A file to be copied. + dst (str or Path): Copy file to to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: If dst specifies a directory, the file will be copied into dst + using the base filename from src. + + Examples: + >>> # dst is a file + >>> src = 's3://openmmlab/mmengine/file' + >>> dst = '/path/of/file' + >>> # src will be copied to '/path/of/file' + >>> copyfile_to_local(src, dst) + '/path/of/file' + + >>> # dst is a directory + >>> dst = '/path/of/dir' + >>> # src will be copied to '/path/of/dir/file' + >>> copyfile_to_local(src, dst) + '/path/of/dir/file' + """ + assert dst_type in ["file", "dir"] + Path(dst).parent.mkdir(parents=True, exist_ok=True) + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copyfile_to_local(src, dst, dst_type=dst_type) + + +def copytree_to_local( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Union[str, Path]: + """Recursively copy an entire directory tree rooted at src to a local + directory named dst and return the destination directory. + + Note: + If the backend is the instance of LocalBackend, it does the same + thing with :func:`copytree`. + + Args: + src (str or Path): A directory to be copied. + dst (str or Path): Copy directory to local dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: The destination directory. + + Examples: + >>> src = 's3://openmmlab/mmengine/dir' + >>> dst = '/path/of/dir' + >>> copytree_to_local(src, dst) + '/path/of/dir' + """ + Path(dst).parent.mkdir(parents=True, exist_ok=True) + backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copytree_to_local(src, dst) + + +def remove( + filepath: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Raises: + FileNotFoundError: If filepath does not exist, an FileNotFoundError + will be raised. + IsADirectoryError: If filepath is a directory, an IsADirectoryError + will be raised. 
+ + Examples: + >>> filepath = '/path/of/file' + >>> remove(filepath) + """ + backend = get_file_backend( + filepath, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.remove(filepath) + + +def rmtree( + dir_path: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> None: + """Recursively delete a directory tree. + + Args: + dir_path (str or Path): A directory to be removed. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Examples: + >>> dir_path = '/path/of/dir' + >>> rmtree(dir_path) + """ + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + backend.rmtree(dir_path) + + +def copy_if_symlink_fails( + src: Union[str, Path], + dst: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> bool: + """Create a symbolic link pointing to src named dst. + + If failed to create a symbolic link pointing to src, directory copy src to + dst instead. + + Args: + src (str or Path): Create a symbolic link pointing to src. + dst (str or Path): Create a symbolic link named dst. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + bool: Return True if successfully create a symbolic link pointing to + src. Otherwise, return False. + + Examples: + >>> src = '/path/of/file' + >>> dst = '/path1/of/file1' + >>> copy_if_symlink_fails(src, dst) + True + >>> src = '/path/of/dir' + >>> dst = '/path1/of/dir1' + >>> copy_if_symlink_fails(src, dst) + True + """ + backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.copy_if_symlink_fails(src, dst) + + +def list_dir( + dir_path: Union[str, Path], + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +): + """List all folders in an S3 bucket with a given prefix. + + Args: + dir_path (str | Path): Path of the directory. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir(dir_path): + ... print(file_path) + """ + if not dir_path.endswith("/"): + dir_path += "/" + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + + return backend.list_dir(dir_path) + + +def list_dir_or_file( + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str or Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix that we are + interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the directory. + Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + + Examples: + >>> dir_path = '/path/of/dir' + >>> for file_path in list_dir_or_file(dir_path): + ... 
print(file_path) + >>> # list those files and directories in current directory + >>> for file_path in list_dir_or_file(dir_path): + ... print(file_path) + >>> # only list files + >>> for file_path in list_dir_or_file(dir_path, list_dir=False): + ... print(file_path) + >>> # only list directories + >>> for file_path in list_dir_or_file(dir_path, list_file=False): + ... print(file_path) + >>> # only list files ending with specified suffixes + >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): + ... print(file_path) + >>> # list all files and directory recursively + >>> for file_path in list_dir_or_file(dir_path, recursive=True): + ... print(file_path) + """ + backend = get_file_backend( + dir_path, + backend_args=backend_args, + enable_singleton=True, + backend_key=backend_key, + ) + yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) + + +def generate_presigned_url( + url: str, + client_method: str = "get_object", + expires_in: int = 3600, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """Generate the presigned url of video stream which can be passed to + mmcv.VideoReader. Now only work on s3 backend. + + Note: + Now only work on s3 backend. + + Args: + url (str): Url of video stream. + client_method (str): Method of client, 'get_object' or + 'put_object'. Defaults to 'get_object'. + expires_in (int): expires, in seconds. Defaults to 3600. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + + Returns: + str: Generated presigned url. + """ + backend = get_file_backend(url, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) + return backend.generate_presigned_url(url, client_method, expires_in) + + +def load( + file: Union[str, Path, IO[Any]], + file_format: Optional[str] = None, + file_client_args: Optional[dict] = None, + fast_backend: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + **kwargs, +): + """Load data from json/yaml/pickle files. + + This method provides a unified api for loading data from serialized files. + + ``load`` supports loading data from serialized files those can be storaged + in different backends. + + Args: + file (str or :obj:`Path` or file-like object): Filename or a file-like + object. + file_format (str, optional): If not specified, the file format will be + inferred from the file extension, otherwise use the specified one. + Currently supported formats include "json", "yaml/yml" and + "pickle/pkl". + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + fast_backend: bool: Whether to use multiprocess. Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + + Examples: + >>> load('/path/of/your/file') # file is storaged in disk + >>> load('https://path/of/your/file') # file is storaged in Internet + >>> load('s3://path/of/your/file') # file is storaged in s3 + + Returns: + The content from the file. 
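    A short round-trip sketch (illustrative; assumes a writable local path, with the
    handler picked from the ``.pkl`` suffix):

        >>> dump({'step': 1}, '/tmp/easy_io_example.pkl')  # hypothetical path
        >>> load('/tmp/easy_io_example.pkl')
        {'step': 1}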
+ """ + if isinstance(file, Path): + file = str(file) + if file_format is None and isinstance(file, str): + file_format = file.split(".")[-1] + # convert file_format to lower case + file_format = file_format.lower() + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', + DeprecationWarning, + ) + if backend_args is not None: + raise ValueError('"file_client_args and "backend_args" cannot be set at the ' "same time.") + + handler = file_handlers[file_format] + if isinstance(file, str): + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend( + file, + backend_args=backend_args, + backend_key=backend_key, + enable_singleton=True, + ) + + if handler.str_like: + with StringIO(file_backend.get_text(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + if fast_backend: + if hasattr(file_backend, "fast_get"): + with BytesIO(file_backend.fast_get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + warnings.warn( + f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get" + ) + with BytesIO(file_backend.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + else: + with BytesIO(file_backend.get(file)) as f: + obj = handler.load_from_fileobj(f, **kwargs) + elif hasattr(file, "read"): + obj = handler.load_from_fileobj(file, **kwargs) + else: + raise TypeError('"file" must be a filepath str or a file-object') + return obj + + +def dump( + obj: Any, + file: Union[str, Path, IO[Any], None] = None, + file_format: Optional[str] = None, + file_client_args: Optional[dict] = None, + fast_backend: bool = False, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + **kwargs, +): + """Dump data to json/yaml/pickle strings or files. + + This method provides a unified api for dumping data as strings or to files, + and also supports custom arguments for each file format. + + ``dump`` supports dumping data as strings or to files which is saved to + different backends. + + Args: + obj (any): The python object to be dumped. + file (str or :obj:`Path` or file-like object, optional): If not + specified, then the object is dumped to a str, otherwise to a file + specified by the filename or file-like object. + file_format (str, optional): Same as :func:`load`. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for details. + Defaults to None. It will be deprecated in future. Please use + ``backend_args`` instead. + fast_backend: bool: Whether to use multiprocess. Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + backend_key: str: The key to register the backend. Defaults to None. + + Examples: + >>> dump('hello world', '/path/of/your/file') # disk + >>> dump('hello world', 's3://path/of/your/file') # ceph or s3 + + Returns: + bool: True for success, False otherwise. 
+ """ + if isinstance(file, Path): + file = str(file) + if file_format is None: + if isinstance(file, str): + file_format = file.split(".")[-1] + elif file is None: + raise ValueError("file_format must be specified since file is None") + # convert file_format to lower case + file_format = file_format.lower() + if file_format not in file_handlers: + raise TypeError(f"Unsupported format: {file_format}") + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', + DeprecationWarning, + ) + if backend_args is not None: + raise ValueError('"file_client_args" and "backend_args" cannot be set at the ' "same time.") + + handler = file_handlers[file_format] + if file is None: + return handler.dump_to_str(obj, **kwargs) + elif isinstance(file, str): + if file_client_args is not None: + file_client = FileClient.infer_client(file_client_args, file) + file_backend = file_client + else: + file_backend = get_file_backend( + file, + backend_args=backend_args, + backend_key=backend_key, + enable_singleton=True, + ) + + if handler.str_like: + with StringIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + file_backend.put_text(f.getvalue(), file) + else: + with BytesIO() as f: + handler.dump_to_fileobj(obj, f, **kwargs) + if fast_backend: + if hasattr(file_backend, "fast_put"): + file_backend.fast_put(f, file) + else: + warnings.warn("fast_backend is not supported by the backend, fallback to normal put") + file_backend.put(f, file) + else: + file_backend.put(f, file) + elif hasattr(file, "write"): + handler.dump_to_fileobj(obj, file, **kwargs) + else: + raise TypeError('"file" must be a filename str or a file-object') diff --git a/cosmos_predict1/utils/easy_io/file_client.py b/cosmos_predict1/utils/easy_io/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..8c963e39515e494a9e4df3288baf68d90769d292 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/file_client.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator, Iterator, Optional, Tuple, Union + +from cosmos_predict1.utils.easy_io.backends import BaseStorageBackend, HTTPBackend, LocalBackend + + +def is_filepath(filepath): + return isinstance(filepath, (str, Path)) + + +class HardDiskBackend(LocalBackend): + """Raw hard disks storage backend.""" + + @property + def name(self): + return self.__class__.__name__ + + +class FileClient: + """A general file client to access files in different backends. + + The client loads a file or text in a specified backend from its path + and returns it as a binary or text file. There are two ways to choose a + backend, the name of backend and the prefix of path. 
Although both of them + can be used to choose a storage backend, ``backend`` has a higher priority + that is if they are all set, the storage backend will be chosen by the + backend argument. If they are all `None`, the disk backend will be chosen. + Note that It can also register other backend accessor with a given name, + prefixes, and backend class. In addition, We use the singleton pattern to + avoid repeated object creation. If the arguments are the same, the same + object will be returned. + + Warning: + `FileClient` will be deprecated in future. Please use io functions + in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io + + Args: + backend (str, optional): The storage backend type. Options are "disk", + "memcached", "lmdb", "http" and "s3". Defaults to None. + prefix (str, optional): The prefix of the registered storage backend. + Options are "s3", "http", "https". Defaults to None. + + Examples: + >>> # only set backend + >>> file_client = FileClient(backend='s3') + >>> # only set prefix + >>> file_client = FileClient(prefix='s3') + >>> # set both backend and prefix but use backend to choose client + >>> file_client = FileClient(backend='s3', prefix='s3') + >>> # if the arguments are the same, the same object is returned + >>> file_client1 = FileClient(backend='s3') + >>> file_client1 is file_client + True + + Attributes: + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + "disk": HardDiskBackend, + "http": HTTPBackend, + } + + _prefix_to_backends: dict = { + "http": HTTPBackend, + "https": HTTPBackend, + } + + _instances: dict = {} + + client: Any + + def __new__(cls, backend=None, prefix=None, **kwargs): + if backend is None and prefix is None: + backend = "disk" + if backend is not None and backend not in cls._backends: + raise ValueError( + f"Backend {backend} is not supported. Currently supported ones" f" are {list(cls._backends.keys())}" + ) + if prefix is not None and prefix not in cls._prefix_to_backends: + raise ValueError( + f"prefix {prefix} is not supported. Currently supported ones " + f"are {list(cls._prefix_to_backends.keys())}" + ) + + # concatenate the arguments to a unique key for determining whether + # objects with the same arguments were created + arg_key = f"{backend}:{prefix}" + for key, value in kwargs.items(): + arg_key += f":{key}:{value}" + + # if a backend was overridden, it will create a new object + if arg_key in cls._instances: + _instance = cls._instances[arg_key] + else: + # create a new object and put it to _instance + _instance = super().__new__(cls) + if backend is not None: + _instance.client = cls._backends[backend](**kwargs) + else: + _instance.client = cls._prefix_to_backends[prefix](**kwargs) + + cls._instances[arg_key] = _instance + + return _instance + + @property + def name(self): + return self.client.name + + @property + def allow_symlink(self): + return self.client.allow_symlink + + @staticmethod + def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: + """Parse the prefix of a uri. + + Args: + uri (str | Path): Uri to be parsed that contains the file prefix. + + Examples: + >>> FileClient.parse_uri_prefix('s3://path/of/your/file') + 's3' + + Returns: + str | None: Return the prefix of uri if the uri contains '://' else + ``None``. 
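        One more illustrative case, a bare local path without ``://``, which yields
        ``None``:

            >>> FileClient.parse_uri_prefix('/path/of/your/file') is None
            True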
+ """ + assert is_filepath(uri) + uri = str(uri) + if "://" not in uri: + return None + else: + prefix, _ = uri.split("://") + if ":" in prefix: + _, prefix = prefix.split(":") + return prefix + + @classmethod + def infer_client( + cls, + file_client_args: Optional[dict] = None, + uri: Optional[Union[str, Path]] = None, + ) -> "FileClient": + """Infer a suitable file client based on the URI and arguments. + + Args: + file_client_args (dict, optional): Arguments to instantiate a + FileClient. Defaults to None. + uri (str | Path, optional): Uri to be parsed that contains the file + prefix. Defaults to None. + + Examples: + >>> uri = 's3://path/of/your/file' + >>> file_client = FileClient.infer_client(uri=uri) + >>> file_client_args = {'backend': 's3'} + >>> file_client = FileClient.infer_client(file_client_args) + + Returns: + FileClient: Instantiated FileClient object. + """ + assert file_client_args is not None or uri is not None + if file_client_args is None: + file_prefix = cls.parse_uri_prefix(uri) # type: ignore + return cls(prefix=file_prefix) + else: + return cls(**file_client_args) + + @classmethod + def _register_backend(cls, name, backend, force=False, prefixes=None): + if not isinstance(name, str): + raise TypeError("the backend name should be a string, " f"but got {type(name)}") + if not inspect.isclass(backend): + raise TypeError(f"backend should be a class but got {type(backend)}") + if not issubclass(backend, BaseStorageBackend): + raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") + if not force and name in cls._backends: + raise KeyError( + f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' + ) + + if name in cls._backends and force: + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, cls._backends[name]): + cls._instances.pop(arg_key) + cls._backends[name] = backend + + if prefixes is not None: + if isinstance(prefixes, str): + prefixes = [prefixes] + else: + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if prefix not in cls._prefix_to_backends: + cls._prefix_to_backends[prefix] = backend + elif (prefix in cls._prefix_to_backends) and force: + overridden_backend = cls._prefix_to_backends[prefix] + for arg_key, instance in list(cls._instances.items()): + if isinstance(instance.client, overridden_backend): + cls._instances.pop(arg_key) + else: + raise KeyError( + f"{prefix} is already registered as a storage backend," + ' add "force=True" if you want to override it' + ) + + @classmethod + def register_backend(cls, name, backend=None, force=False, prefixes=None): + """Register a backend to FileClient. + + This method can be used as a normal class method or a decorator. + + .. code-block:: python + + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + FileClient.register_backend('new', NewBackend) + + or + + .. code-block:: python + + @FileClient.register_backend('new') + class NewBackend(BaseStorageBackend): + + def get(self, filepath): + return filepath + + def get_text(self, filepath): + return filepath + + Args: + name (str): The name of the registered backend. + backend (class, optional): The backend class to be registered, + which must be a subclass of :class:`BaseStorageBackend`. + When this method is used as a decorator, backend is None. + Defaults to None. 
+ force (bool, optional): Whether to override the backend if the name + has already been registered. Defaults to False. + prefixes (str or list[str] or tuple[str], optional): The prefixes + of the registered storage backend. Defaults to None. + `New in version 1.3.15.` + """ + if backend is not None: + cls._register_backend(name, backend, force=force, prefixes=prefixes) + return + + def _register(backend_cls): + cls._register_backend(name, backend_cls, force=force, prefixes=prefixes) + return backend_cls + + return _register + + def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: + """Read data from a given ``filepath`` with 'rb' mode. + + Note: + There are two types of return values for ``get``, one is ``bytes`` + and the other is ``memoryview``. The advantage of using memoryview + is that you can avoid copying, and if you want to convert it to + ``bytes``, you can use ``.tobytes()``. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes | memoryview: Expected bytes object or a memory view of the + bytes object. + """ + return self.client.get(filepath) + + def get_text(self, filepath: Union[str, Path], encoding="utf-8") -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Defaults to 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. + """ + return self.client.get_text(filepath, encoding) + + def put(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``put`` should create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + self.client.put(obj, filepath) + + def put_text(self, obj: str, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``put_text`` should create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str, optional): The encoding format used to open the + `filepath`. Defaults to 'utf-8'. + """ + self.client.put_text(obj, filepath) + + def remove(self, filepath: Union[str, Path]) -> None: + """Remove a file. + + Args: + filepath (str, Path): Path to be removed. + """ + self.client.remove(filepath) + + def exists(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path exists. + + Args: + filepath (str or Path): Path to be checked whether exists. + + Returns: + bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. + """ + return self.client.exists(filepath) + + def isdir(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a directory. + + Args: + filepath (str or Path): Path to be checked whether it is a + directory. + + Returns: + bool: Return ``True`` if ``filepath`` points to a directory, + ``False`` otherwise. + """ + return self.client.isdir(filepath) + + def isfile(self, filepath: Union[str, Path]) -> bool: + """Check whether a file path is a file. + + Args: + filepath (str or Path): Path to be checked whether it is a file. + + Returns: + bool: Return ``True`` if ``filepath`` points to a file, ``False`` + otherwise. + """ + return self.client.isfile(filepath) + + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + r"""Concatenate all file paths. 
+ + Join one or more filepath components intelligently. The return value + is the concatenation of filepath and any members of \*filepaths. + + Args: + filepath (str or Path): Path to be concatenated. + + Returns: + str: The result of concatenation. + """ + return self.client.join_path(filepath, *filepaths) + + @contextmanager + def get_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]: + """Download data from ``filepath`` and write the data to local path. + + ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Note: + If the ``filepath`` is a local path, just return itself. + + .. warning:: + ``get_local_path`` is an experimental interface that may change in + the future. + + Args: + filepath (str or Path): Path to be read data. + + Examples: + >>> file_client = FileClient(prefix='s3') + >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: + ... # do something here + + Yields: + Iterable[str]: Only yield one path. + """ + with self.client.get_local_path(str(filepath)) as local_path: + yield local_path + + def list_dir_or_file( # pylint: disable=too-many-arguments + self, + dir_path: Union[str, Path], + list_dir: bool = True, + list_file: bool = True, + suffix: Optional[Union[str, Tuple[str]]] = None, + recursive: bool = False, + ) -> Iterator[str]: + """Scan a directory to find the interested directories or files in + arbitrary order. + + Note: + :meth:`list_dir_or_file` returns the path relative to ``dir_path``. + + Args: + dir_path (str | Path): Path of the directory. + list_dir (bool): List the directories. Defaults to True. + list_file (bool): List the path of files. Defaults to True. + suffix (str or tuple[str], optional): File suffix + that we are interested in. Defaults to None. + recursive (bool): If set to True, recursively scan the + directory. Defaults to False. + + Yields: + Iterable[str]: A relative path to ``dir_path``. + """ + yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos_predict1/utils/easy_io/handlers/__init__.py b/cosmos_predict1/utils/easy_io/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2d900319026db39d87cc206881c40df9aedb97 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
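A short usage sketch for the ``FileClient`` wrapper defined above (illustrative; the path is hypothetical and the default ``disk`` backend is assumed):

from cosmos_predict1.utils.easy_io.file_client import FileClient

client = FileClient()  # no backend/prefix given -> falls back to the disk backend
client.put_text("hello world", "/tmp/easy_io_demo.txt")  # hypothetical path
assert client.get_text("/tmp/easy_io_demo.txt") == "hello world"
assert client.isfile("/tmp/easy_io_demo.txt")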
+ +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler +from cosmos_predict1.utils.easy_io.handlers.json_handler import JsonHandler +from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler +from cosmos_predict1.utils.easy_io.handlers.registry_utils import file_handlers, register_handler +from cosmos_predict1.utils.easy_io.handlers.yaml_handler import YamlHandler + +__all__ = [ + "BaseFileHandler", + "JsonHandler", + "PickleHandler", + "YamlHandler", + "register_handler", + "file_handlers", +] diff --git a/cosmos_predict1/utils/easy_io/handlers/base.py b/cosmos_predict1/utils/easy_io/handlers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5dcbcabc40807706eeb43d1a598571c51922a8 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/base.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABCMeta, abstractmethod + + +class BaseFileHandler(metaclass=ABCMeta): + # `str_like` is a flag to indicate whether the type of file object is + # str-like object or bytes-like object. Pickle only processes bytes-like + # objects but json only processes str-like object. If it is str-like + # object, `StringIO` will be used to process the buffer. + str_like = True + + @abstractmethod + def load_from_fileobj(self, file, **kwargs): + pass + + @abstractmethod + def dump_to_fileobj(self, obj, file, **kwargs): + pass + + @abstractmethod + def dump_to_str(self, obj, **kwargs): + pass + + def load_from_path(self, filepath, mode="r", **kwargs): + with open(filepath, mode) as f: + return self.load_from_fileobj(f, **kwargs) + + def dump_to_path(self, obj, filepath, mode="w", **kwargs): + with open(filepath, mode) as f: + self.dump_to_fileobj(obj, f, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/csv_handler.py b/cosmos_predict1/utils/easy_io/handlers/csv_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..58d6493be50257de285669727f62a61372b1cf0a --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/csv_handler.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
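A minimal custom handler built on the ``BaseFileHandler`` contract above (illustrative; ``TxtHandler`` is a made-up name, and only the three abstract methods are implemented so the inherited path helpers do the file handling):

from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler


class TxtHandler(BaseFileHandler):
    """Hypothetical plain-text handler; str_like keeps its default of True."""

    def load_from_fileobj(self, file, **kwargs):
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)


handler = TxtHandler()
handler.dump_to_path("hello", "/tmp/note.txt")      # hypothetical path; opens with mode "w"
assert handler.load_from_path("/tmp/note.txt") == "hello"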
+ +import csv +from io import StringIO + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class CsvHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + del kwargs + reader = csv.reader(file) + return list(reader) + + def dump_to_fileobj(self, obj, file, **kwargs): + del kwargs + writer = csv.writer(file) + if not all(isinstance(row, list) for row in obj): + raise ValueError("Each row must be a list") + writer.writerows(obj) + + def dump_to_str(self, obj, **kwargs): + del kwargs + output = StringIO() + writer = csv.writer(output) + if not all(isinstance(row, list) for row in obj): + raise ValueError("Each row must be a list") + writer.writerows(obj) + return output.getvalue() diff --git a/cosmos_predict1/utils/easy_io/handlers/gzip_handler.py b/cosmos_predict1/utils/easy_io/handlers/gzip_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..205f6abb2a002438bd072dd86e21ea845c35c8bd --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/gzip_handler.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import pickle +from io import BytesIO +from typing import Any + +from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler + + +class GzipHandler(PickleHandler): + str_like = False + + def load_from_fileobj(self, file: BytesIO, **kwargs): + with gzip.GzipFile(fileobj=file, mode="rb") as f: + return pickle.load(f) + + def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): + with gzip.GzipFile(fileobj=file, mode="wb") as f: + pickle.dump(obj, f) diff --git a/cosmos_predict1/utils/easy_io/handlers/imageio_video_handler.py b/cosmos_predict1/utils/easy_io/handlers/imageio_video_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..30551d176fc842e71d91e9a62c336a60244f289a --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/imageio_video_handler.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
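A quick in-memory round trip for the ``GzipHandler`` above (illustrative; it pickles the object and gzip-compresses the result, so everything stays in a ``BytesIO`` buffer):

from io import BytesIO

from cosmos_predict1.utils.easy_io.handlers.gzip_handler import GzipHandler

handler = GzipHandler()
buf = BytesIO()
handler.dump_to_fileobj({"epoch": 3}, buf)  # gzip-wrapped pickle.dump
buf.seek(0)
assert handler.load_from_fileobj(buf) == {"epoch": 3}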
+ +from typing import IO + +import numpy as np +import torch + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + +try: + import imageio +except ImportError: + imageio = None + + +class ImageioVideoHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs): + """ + Load video from a file-like object using imageio with specified format and color mode. + + Parameters: + file (IO[bytes]): A file-like object containing video data. + format (str): Format of the video file (default 'mp4'). + mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). + + Returns: + tuple: A tuple containing an array of video frames and metadata about the video. + """ + file.seek(0) + video_reader = imageio.get_reader(file, format, **kwargs) + + video_frames = [] + for frame in video_reader: + if mode == "gray": + import cv2  # OpenCV is only needed for the grayscale conversion + + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = np.expand_dims(frame, axis=2)  # Keep frame dimensions consistent + video_frames.append(frame) + + return np.array(video_frames), video_reader.get_meta_data() + + def dump_to_fileobj( + self, + obj: np.ndarray | torch.Tensor, + file: IO[bytes], + format: str = "mp4",  # pylint: disable=redefined-builtin + fps: int = 17, + quality: int = 5, + **kwargs, + ): + """ + Save an array of video frames to a file-like object using imageio. + + Parameters: + obj (np.ndarray | torch.Tensor): A uint8 array of frames with shape (T, H, W, C) to be saved as video. + file (IO[bytes]): A file-like object to which the video data will be written. + format (str): Format of the video file (default 'mp4'). + fps (int): Frames per second of the output video (default 17). + quality (int): Output quality passed to the imageio ffmpeg writer (default 5). + """ + if isinstance(obj, torch.Tensor): + assert obj.dtype == torch.uint8 + obj = obj.cpu().numpy() + h, w = obj.shape[1:-1] + # Note: caller-provided kwargs are intentionally replaced by a fixed set of imageio/ffmpeg options. + kwargs = { + "fps": fps, + "quality": quality, + "macro_block_size": 1, + "ffmpeg_params": ["-s", f"{w}x{h}"], + "output_params": ["-f", "mp4"], + } + imageio.mimsave(file, obj, format, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/json_handler.py b/cosmos_predict1/utils/easy_io/handlers/json_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe9ffbe2aa20c4ea467bcdec29c8e7c2917c473 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/json_handler.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.)
into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f"{type(obj)} is unsupported for json dump") + + +class JsonHandler(BaseFileHandler): + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault("default", set_default) + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("default", set_default) + return json.dumps(obj, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/jsonl_handler.py b/cosmos_predict1/utils/easy_io/handlers/jsonl_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..b30ce6b1959b05268a842409274e1251a4765672 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/jsonl_handler.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import IO + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, + etc.) into plain numbers of plain python built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f"{type(obj)} is unsupported for json dump") + + +class JsonlHandler(BaseFileHandler): + """Handler for JSON lines (JSONL) files.""" + + def load_from_fileobj(self, file: IO[bytes]): + """Load JSON objects from a newline-delimited JSON (JSONL) file object. + + Returns: + A list of Python objects loaded from each JSON line. + """ + data = [] + for line in file: + line = line.strip() + if not line: + continue # skip empty lines if any + data.append(json.loads(line)) + return data + + def dump_to_fileobj(self, obj: IO[bytes], file, **kwargs): + """Dump a list of objects to a newline-delimited JSON (JSONL) file object. + + Args: + obj: A list (or iterable) of objects to dump line by line. 
+ """ + kwargs.setdefault("default", set_default) + for item in obj: + file.write(json.dumps(item, **kwargs) + "\n") + + def dump_to_str(self, obj, **kwargs): + """Dump a list of objects to a newline-delimited JSON (JSONL) string.""" + kwargs.setdefault("default", set_default) + lines = [json.dumps(item, **kwargs) for item in obj] + return "\n".join(lines) + + +if __name__ == "__main__": + from cosmos_predict1.utils.easy_io import easy_io + + easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl") + print(easy_io.load("test.jsonl")) + easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl") + print(easy_io.load("test.jsonl")) diff --git a/cosmos_predict1/utils/easy_io/handlers/np_handler.py b/cosmos_predict1/utils/easy_io/handlers/np_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..29cb8d55c181c14a06022d97523166b3763cb753 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/np_handler.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import BytesIO +from typing import IO, Any + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class NumpyHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any: + """ + Load a NumPy array from a file-like object. + + Parameters: + file (IO[bytes]): The file-like object containing the NumPy array data. + **kwargs: Additional keyword arguments passed to `np.load`. + + Returns: + numpy.ndarray: The loaded NumPy array. + """ + return np.load(file, **kwargs) + + def load_from_path(self, filepath: str, **kwargs) -> Any: + """ + Load a NumPy array from a file path. + + Parameters: + filepath (str): The path to the file to load. + **kwargs: Additional keyword arguments passed to `np.load`. + + Returns: + numpy.ndarray: The loaded NumPy array. + """ + return super().load_from_path(filepath, mode="rb", **kwargs) + + def dump_to_str(self, obj: np.ndarray, **kwargs) -> str: + """ + Serialize a NumPy array to a string in binary format. + + Parameters: + obj (np.ndarray): The NumPy array to serialize. + **kwargs: Additional keyword arguments passed to `np.save`. + + Returns: + str: The serialized NumPy array as a string. + """ + with BytesIO() as f: + np.save(f, obj, **kwargs) + return f.getvalue() + + def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs): + """ + Dump a NumPy array to a file-like object. + + Parameters: + obj (np.ndarray): The NumPy array to dump. + file (IO[bytes]): The file-like object to which the array is dumped. + **kwargs: Additional keyword arguments passed to `np.save`. + """ + np.save(file, obj, **kwargs) + + def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs): + """ + Dump a NumPy array to a file path. + + Parameters: + obj (np.ndarray): The NumPy array to dump. 
+ filepath (str): The file path where the array should be saved. + **kwargs: Additional keyword arguments passed to `np.save`. + """ + with open(filepath, "wb") as f: + np.save(f, obj, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/pandas_handler.py b/cosmos_predict1/utils/easy_io/handlers/pandas_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..cdcac6e6eb82f8e92c79c4fef16e1f5b68dbd82c --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/pandas_handler.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler # isort:skip + + +class PandasHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return pd.read_csv(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + obj.to_csv(file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError("PandasHandler does not support dumping to str") diff --git a/cosmos_predict1/utils/easy_io/handlers/pickle_handler.py b/cosmos_predict1/utils/easy_io/handlers/pickle_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..66bc10d5f2da24785cef4216e2961747c52eb756 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/pickle_handler.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
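A small usage sketch for NumpyHandler with an in-memory buffer; note that despite its name, dump_to_str returns the raw bytes of the .npy serialization rather than text.

```python
from io import BytesIO

import numpy as np

from cosmos_predict1.utils.easy_io.handlers.np_handler import NumpyHandler

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
buf = BytesIO()
NumpyHandler().dump_to_fileobj(arr, buf)   # writes the standard .npy format
buf.seek(0)
restored = NumpyHandler().load_from_fileobj(buf)
assert np.array_equal(arr, restored)
```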
+ +import pickle +from io import BytesIO +from typing import Any + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class PickleHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file: BytesIO, **kwargs): + return pickle.load(file, **kwargs) + + def load_from_path(self, filepath, **kwargs): + return super().load_from_path(filepath, mode="rb", **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("protocol", 2) + return pickle.dumps(obj, **kwargs) + + def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): + kwargs.setdefault("protocol", 2) + pickle.dump(obj, file, **kwargs) + + def dump_to_path(self, obj, filepath, **kwargs): + with open(filepath, "wb") as f: + pickle.dump(obj, f, **kwargs) diff --git a/cosmos_predict1/utils/easy_io/handlers/pil_handler.py b/cosmos_predict1/utils/easy_io/handlers/pil_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a473bf486ce87df49e9f48afc6efec81abe8cbc4 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/pil_handler.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import IO, Optional, Tuple, Union + +import numpy as np + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + +try: + from PIL import Image +except ImportError: + Image = None + + +class PILHandler(BaseFileHandler): + format: str + str_like = False + + def load_from_fileobj( + self, + file: IO[bytes], + fmt: str = "pil", + size: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ): + """ + Load an image from a file-like object and return it in a specified format. + + Args: + file (IO[bytes]): A file-like object containing the image data. + fmt (str): The format to convert the image into. Options are \ + 'numpy', 'np', 'npy', 'type' (all return numpy arrays), \ + 'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor). + size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \ + or a tuple of (width, height). If specified, the image is resized accordingly. + **kwargs: Additional keyword arguments that can be passed to conversion functions. + + Returns: + Image data in the format specified by `fmt`. + + Raises: + IOError: If the image cannot be loaded or processed. + ValueError: If the specified format is unsupported. 
+ """ + try: + img = Image.open(file) + img.load()  # Explicitly load the image data + if size is not None: + if isinstance(size, int): + size = ( + size, + size, + )  # create a tuple if only one integer is provided + img = img.resize(size, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter + + # Return the image in the requested format + if fmt in ["numpy", "np", "npy"]: + return np.array(img, **kwargs) + if fmt == "pil": + return img + if fmt in ["th", "torch"]: + import torch + + # Convert to tensor + img_tensor = torch.from_numpy(np.array(img, **kwargs)) + # Convert image from HxWxC to CxHxW + if img_tensor.ndim == 3: + img_tensor = img_tensor.permute(2, 0, 1) + return img_tensor + raise ValueError( + "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'." + ) + except Exception as e: + raise IOError(f"Unable to load image: {e}") from e + + def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs): + if "format" not in kwargs: + # PIL expects upper-case format names ("JPEG" for .jpg); do not override a caller-supplied format. + kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper() + obj.save(file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/registry_utils.py b/cosmos_predict1/utils/easy_io/handlers/registry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec7edccc3d7c445d7034c028950cd823a17aef8b --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/registry_utils.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
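A usage sketch for PILHandler (the example.png path is a placeholder): images can be returned as PIL objects, NumPy arrays, or CHW torch tensors, and an integer size is expanded to a square (size, size) before resizing.

```python
from cosmos_predict1.utils.easy_io.handlers.pil_handler import PILHandler

handler = PILHandler()
handler.format = "png"  # the registry below assigns this attribute per file extension
with open("example.png", "rb") as f:  # placeholder path
    tensor = handler.load_from_fileobj(f, fmt="torch", size=256)
print(tensor.shape)  # CxHxW uint8 tensor, resized to 256x256
```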
+ +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler +from cosmos_predict1.utils.easy_io.handlers.csv_handler import CsvHandler +from cosmos_predict1.utils.easy_io.handlers.gzip_handler import GzipHandler +from cosmos_predict1.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler +from cosmos_predict1.utils.easy_io.handlers.json_handler import JsonHandler +from cosmos_predict1.utils.easy_io.handlers.jsonl_handler import JsonlHandler +from cosmos_predict1.utils.easy_io.handlers.np_handler import NumpyHandler +from cosmos_predict1.utils.easy_io.handlers.pandas_handler import PandasHandler +from cosmos_predict1.utils.easy_io.handlers.pickle_handler import PickleHandler +from cosmos_predict1.utils.easy_io.handlers.pil_handler import PILHandler +from cosmos_predict1.utils.easy_io.handlers.tarfile_handler import TarHandler +from cosmos_predict1.utils.easy_io.handlers.torch_handler import TorchHandler +from cosmos_predict1.utils.easy_io.handlers.torchjit_handler import TorchJitHandler +from cosmos_predict1.utils.easy_io.handlers.txt_handler import TxtHandler +from cosmos_predict1.utils.easy_io.handlers.yaml_handler import YamlHandler + +file_handlers = { + "json": JsonHandler(), + "yaml": YamlHandler(), + "yml": YamlHandler(), + "pickle": PickleHandler(), + "pkl": PickleHandler(), + "tar": TarHandler(), + "jit": TorchJitHandler(), + "npy": NumpyHandler(), + "txt": TxtHandler(), + "csv": CsvHandler(), + "pandas": PandasHandler(), + "gz": GzipHandler(), + "jsonl": JsonlHandler(), +} + +for torch_type in ["pt", "pth", "ckpt"]: + file_handlers[torch_type] = TorchHandler() +for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]: + file_handlers[img_type] = PILHandler() + file_handlers[img_type].format = img_type +for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]: + file_handlers[video_type] = ImageioVideoHandler() + + +def _register_handler(handler, file_formats): + """Register a handler for some file extensions. + + Args: + handler (:obj:`BaseFileHandler`): Handler to be registered. + file_formats (str or list[str]): File formats to be handled by this + handler. + """ + if not isinstance(handler, BaseFileHandler): + raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}") + if isinstance(file_formats, str): + file_formats = [file_formats] + if not all([isinstance(item, str) for item in file_formats]): + raise TypeError("file_formats must be a str or a list of str") + for ext in file_formats: + file_handlers[ext] = handler + + +def register_handler(file_formats, **kwargs): + def wrap(cls): + _register_handler(cls(**kwargs), file_formats) + return cls + + return wrap diff --git a/cosmos_predict1/utils/easy_io/handlers/tarfile_handler.py b/cosmos_predict1/utils/easy_io/handlers/tarfile_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..687e8a8a3adafeaf88b8a2644472943698dffe26 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/tarfile_handler.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tarfile + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TarHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, mode="r|*", **kwargs): + return tarfile.open(fileobj=file, mode=mode, **kwargs) + + def load_from_path(self, filepath, mode="r|*", **kwargs): + return tarfile.open(filepath, mode=mode, **kwargs) + + def dump_to_fileobj(self, obj, file, mode="w", **kwargs): + with tarfile.open(fileobj=file, mode=mode) as tar: + tar.add(obj, **kwargs) + + def dump_to_path(self, obj, filepath, mode="w", **kwargs): + with tarfile.open(filepath, mode=mode) as tar: + tar.add(obj, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/torch_handler.py b/cosmos_predict1/utils/easy_io/handlers/torch_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f64eafe59a9593c4e8aed6513092e9604faae378 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/torch_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import torch +except ImportError: + torch = None + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TorchHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return torch.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + torch.save(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/torchjit_handler.py b/cosmos_predict1/utils/easy_io/handlers/torchjit_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..71598cdaf2679ed47293dcc0410be28b9b4b0a91 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/torchjit_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import torch +except ImportError: + torch = None + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TorchJitHandler(BaseFileHandler): + str_like = False + + def load_from_fileobj(self, file, **kwargs): + return torch.jit.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + torch.jit.save(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + raise NotImplementedError diff --git a/cosmos_predict1/utils/easy_io/handlers/txt_handler.py b/cosmos_predict1/utils/easy_io/handlers/txt_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..007d7f661d0d5ad001207a0424c35e51aea9e1a9 --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/txt_handler.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler + + +class TxtHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + del kwargs + return file.read() + + def dump_to_fileobj(self, obj, file, **kwargs): + del kwargs + if not isinstance(obj, str): + obj = str(obj) + file.write(obj) + + def dump_to_str(self, obj, **kwargs): + del kwargs + if not isinstance(obj, str): + obj = str(obj) + return obj diff --git a/cosmos_predict1/utils/easy_io/handlers/yaml_handler.py b/cosmos_predict1/utils/easy_io/handlers/yaml_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ede0280f41d7fa29468d46141ff31e038dbcad4c --- /dev/null +++ b/cosmos_predict1/utils/easy_io/handlers/yaml_handler.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
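To extend the registry shown above with a new extension, the register_handler decorator can be used. The sketch below assumes Python 3.11's read-only tomllib module and a hypothetical "toml" extension; any handler that implements the three abstract methods can be registered this way.

```python
from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler
from cosmos_predict1.utils.easy_io.handlers.registry_utils import file_handlers, register_handler

try:
    import tomllib  # Python 3.11+, read-only TOML parser
except ImportError:
    tomllib = None


@register_handler("toml")
class TomlHandler(BaseFileHandler):
    str_like = False  # tomllib.load expects a binary file object

    def load_from_fileobj(self, file, **kwargs):
        return tomllib.load(file)

    def dump_to_fileobj(self, obj, file, **kwargs):
        raise NotImplementedError("tomllib is read-only")

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError("tomllib is read-only")


assert isinstance(file_handlers["toml"], TomlHandler)
```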
+ +import yaml + +try: + from yaml import CDumper as Dumper  # type: ignore + from yaml import CLoader as Loader  # type: ignore +except ImportError: + from yaml import Loader, Dumper  # type: ignore + +from cosmos_predict1.utils.easy_io.handlers.base import BaseFileHandler  # isort:skip + + +class YamlHandler(BaseFileHandler): + def load_from_fileobj(self, file, **kwargs): + kwargs.setdefault("Loader", Loader) + return yaml.load(file, **kwargs) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault("Dumper", Dumper) + yaml.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault("Dumper", Dumper) + return yaml.dump(obj, **kwargs) diff --git a/cosmos_predict1/utils/ema.py b/cosmos_predict1/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..7d883648a2ca969fb11e61591c01e85f8d488513 --- /dev/null +++ b/cosmos_predict1/utils/ema.py @@ -0,0 +1,327 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union + +import numpy as np +import torch +from megatron.core import parallel_state + +from cosmos_predict1.utils import distributed, log + +if TYPE_CHECKING: + from cosmos_predict1.utils.model import Model + + +class FastEmaModelUpdater: + """ + Updates a target (EMA) model from a source (regular) model given a decay rate beta. + The method interface mimics :class:`EMAModelTracker` and :class:`PowerEMATracker`. + Unlike those two classes, it does not maintain the EMA weights as buffers; it expects the caller to provide two modules with the same architecture and weight shapes. + It is intended for FSDP models, where the two classes above do not work as expected, and it avoids registering model weights as buffers and the name mangling that entails. Moving forward, this class should be used instead of the two above. + """ + + def __init__(self): + # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite + self.is_cached = False + + def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None: + target_list = [] + source_list = [] + for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + assert ( + tgt_params.dtype == torch.float32 + ), f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead."
+ target_list.append(tgt_params) + source_list.append(src_params.data) + torch._foreach_mul_(target_list, beta) + torch._foreach_add_(target_list, source_list, alpha=1.0 - beta) + + def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None: + for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + tgt_params.data.copy_(src_params.data) + + def cache(self, parameters: Any, is_cpu: bool = False) -> None: + """Save the current parameters for restoring later. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. + """ + assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" + device = "cpu" if is_cpu else "cuda" + self.collected_params = [param.clone().to(device) for param in parameters] + self.is_cached = True + + def restore(self, parameters: Any) -> None: + """Restore the parameters in self.collected_params. + + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before copy_to(). + After validation (or model saving), use this to restore the former parameters. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. + """ + assert self.is_cached, "EMA cache is not taken yet." + for c_param, param in zip(self.collected_params, parameters, strict=False): + param.data.copy_(c_param.data.type_as(param.data)) + self.collected_params = [] + # Release the cache after we call restore + self.is_cached = False + + +def get_buffer_name(param_name: str, torch_compile_buffer_renaming: bool = False) -> str: + """ + This function creates buffer name used by EMA from parameter's name + + Args: + param_name (str): Model's parameter name + Returns: + buffer_name (str): buffer name to be used for given parameter name + """ + + buffer_name = param_name.replace(".", "-") + + if torch_compile_buffer_renaming: + # torch.compile() adds _orig_mod to state dict names, this way we get original name + buffer_name = buffer_name.replace("_orig_mod-", "") + + return buffer_name + + +class EMAModelTracker(torch.nn.Module): + """This is a class to track the EMA model weights. + + The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the + regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's + implementation of EMA. There are no optimizable parameters. + + Attributes: + collected_params (list): temporarily stores the regular weights while in EMA mode. + beta (float): EMA decay rate. (default: 0.9999). + torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used + """ + + def __init__(self, model: Model, beta: float = 0.9999, torch_compile_buffer_renaming: bool = False): + """Constructor of the EMA model weight tracker. + + Args: + model (Model): The PyTorch model. + beta (float): EMA decay rate. (default: 0.9999). + """ + super().__init__() + self.torch_compile_buffer_renaming: bool = torch_compile_buffer_renaming + if not 0.0 <= beta <= 1.0: + raise ValueError("Decay must be between 0 and 1") + self.beta = beta + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + self.register_buffer(buffer_name, param.clone().detach().data) + self.collected_params = [] + # Flag to indicate whether the cache is taken or not. 
Useful to avoid cache overwrite + self.is_cached = False + + @torch.no_grad() + def update_average(self, model: Model, iteration: Optional[int] = None) -> None: + del iteration + target_list = [] + source_list = [] + ema_buffers = self.state_dict() + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + buffer = ema_buffers[buffer_name] + assert buffer.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {buffer.dtype} instead." + target_list.append(buffer) + source_list.append(param.data) + torch._foreach_mul_(target_list, self.beta) + torch._foreach_add_(target_list, source_list, alpha=1.0 - self.beta) + + def copy_to(self, model: Model) -> None: + ema_buffers = self.state_dict() + for name, param in model.named_parameters(): + if param.requires_grad: + buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) + buffer = ema_buffers[buffer_name] + param.data.copy_(buffer.data) + + def cache(self, parameters: Any, is_cpu: bool = False) -> None: + """Save the current parameters for restoring later. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. + """ + assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" + device = "cpu" if is_cpu else "cuda" + self.collected_params = [param.clone().to(device) for param in parameters] + self.is_cached = True + + def restore(self, parameters: Any) -> None: + """Restore the parameters in self.collected_params. + + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before copy_to(). + After validation (or model saving), use this to restore the former parameters. + + Args: + parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. + """ + assert self.is_cached, "EMA cache is not taken yet." + for c_param, param in zip(self.collected_params, parameters, strict=False): + param.data.copy_(c_param.data.type_as(param.data)) + self.collected_params = [] + # Release the cache after we call restore + self.is_cached = False + + @classmethod + def initialize_multi_rank_ema( + cls, model: torch.nn.Module, rate: Union[float, List[float]], num: int = 1, enabled: bool = True + ) -> Optional[EMAModelTracker]: + """ + Class method to initialize per rank EMA Model Tracker with different rate. + Each rank will have a different rate based on the given configuration, resulting in different EMA weights. + + Args: + model (torch.nn.Module): The neural network model to be tracked. + rate (Union[float, List[float]]): The decay rate(s) for the EMA. If a list is provided, + it corresponds to rates for different ranks. + num (int, optional): The number of leading ranks to consider for different rates. + Defaults to 1. + enabled (bool, optional): Flag to enable or disable the creation of the tracker. + If False, returns None. Defaults to True. + + Returns: + Optional[EMAModelTracker]: An instance of EMAModelTracker if enabled, otherwise None. + + Example: + >>> model = torch.nn.Linear(10, 2) + >>> tracker = EMAModelTracker.initialize_ema_from_settings(model, rate=[0.1, 0.2], num=2) + >>> print(tracker) + + Notes: + If `rate` is a list and the current rank is less than `num`, the rate for the current rank + is used. If the current rank exceeds `num`, the first rate in the list is used by default. 
+ """ + if not enabled: + return None + if parallel_state.is_initialized(): + cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + log.warning("It should not used together with FSDP!") + else: + cur_dp_rank = distributed.get_rank() + log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + rate = rate if isinstance(rate, list) else [rate] + num = min(num, len(rate)) + rate = rate[cur_dp_rank] if cur_dp_rank < num else rate[0] + if cur_dp_rank < num: + print(f"EMAModelTracker: rank {cur_dp_rank}, rate {rate}") + return cls(model, rate) + + +class PowerEMATracker(EMAModelTracker): + def __init__(self, model: Model, s: float = 0.1, torch_compile_buffer_renaming: bool = False): + """Constructor of the EMA model weight tracker. + + Args: + model (Model): The PyTorch model. + s (float): EMA decay rate. See EDM2 paper + torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used + """ + super().__init__(model=model, beta=0.0, torch_compile_buffer_renaming=torch_compile_buffer_renaming) + self.exp = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() + + @torch.no_grad() + def update_average(self, model: Model, iteration: Optional[int] = None) -> None: + if iteration == 0: + beta = 0.0 + else: + i = iteration + 1 + beta = (1 - 1 / i) ** (self.exp + 1) + self.beta = beta + + super().update_average(model, iteration) + + @classmethod + def initialize_multi_rank_ema( + cls, model: torch.nn.Module, rate: float, num: int, enabled: bool = True + ) -> Optional[PowerEMATracker]: + """ + Class method to initialize per rank EMA Model Tracker with different rate. + Each rank will have a different rate based on the given configuration, resulting in different EMA weights. + + Args: + model (torch.nn.Module): The neural network model for which the EMA tracker is being set up. + num (int): The number of ranks for which the rate adjustment is applied. Beyond this, the rate remains unchanged. + rate (float): The base decay rate for the EMA calculation. + enabled (bool, optional): Flag to enable or disable the initialization of the tracker. If False, returns None. + Defaults to True. + + Returns: + Optional[PowerEMATracker]: An instance of PowerEMATracker with adjusted rate if enabled, otherwise None. + + Raises: + None + + Example: + >>> model = torch.nn.Linear(10, 2) + >>> tracker = PowerEMATracker.initialize_multi_rank_ema(model, num=3, rate=0.99) + >>> print(tracker) + + Notes: + The decay rate is modified by dividing it by 2 raised to the power of the rank for each rank less than `num`. + If the rank is greater than or equal to `num`, the base rate is used without modification. This approach + allows higher ranked processes to have a less aggressive decay, potentially reflecting their delayed synchronization + in a distributed training scenario. + """ + if not enabled: + return None + if parallel_state.is_initialized(): + cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) + log.warning("It should not used together with FSDP!") + else: + cur_dp_rank = distributed.get_rank() + log.critical(f"using torch.distributed for EMA initialization. 
DP RANK: {cur_dp_rank}", rank0_only=False) + + divider = 2**cur_dp_rank if cur_dp_rank < num else 1 + if cur_dp_rank < num: + print(f"PowerEMATracker: rank {cur_dp_rank}, rate {rate / divider}") + return cls(model, rate / divider) + + +@contextmanager +def ema_scope(model: Model, enabled: bool = False) -> Generator[None, None, None]: + """Context manager for switching between regular and EMA model weights. + + Args: + model (Model): The PyTorch model. + enabled (bool): Whether switching to EMA weights is enabled (default: False). + """ + if enabled: + assert hasattr(model, "ema") and isinstance(model.ema, (FastEmaModelUpdater, EMAModelTracker, PowerEMATracker)) + model.ema.cache(model.parameters()) + model.ema.copy_to(model) + log.info("EMA: switched to EMA weights.") + try: + yield None + finally: + if enabled: + model.ema.restore(model.parameters()) + log.info("EMA: restored regular weights.") diff --git a/cosmos_predict1/utils/env_parsers/cred_env_parser.py b/cosmos_predict1/utils/env_parsers/cred_env_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa34d8402cfd80d86bed87be31863b3afd0ceaa --- /dev/null +++ b/cosmos_predict1/utils/env_parsers/cred_env_parser.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
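A sketch of the FastEmaModelUpdater flow, assuming the package and its megatron dependency are importable; a toy nn.Linear stands in for the real model. The cache/copy_to/restore pattern at the end is the same one that ema_scope wraps for trackers attached to a Model.

```python
import copy

import torch

from cosmos_predict1.utils.ema import FastEmaModelUpdater

net = torch.nn.Linear(8, 4)
ema_net = copy.deepcopy(net)  # same architecture and shapes, FP32
updater = FastEmaModelUpdater()

for _ in range(10):
    # a real training step would update `net` here; we just perturb it
    with torch.no_grad():
        for p in net.parameters():
            p.add_(0.01 * torch.randn_like(p))
    updater.update_average(net, ema_net, beta=0.9999)

# Temporarily evaluate with EMA weights, then restore the regular ones.
updater.cache(net.parameters(), is_cpu=True)
updater.copy_to(src_model=ema_net, tgt_model=net)
# ... run validation with `net` ...
updater.restore(net.parameters())
```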
+ +from cosmos_predict1.utils.env_parsers.env_parser import EnvParser +from cosmos_predict1.utils.validator import String + + +class CredentialEnvParser(EnvParser): + APP_ENV = String(default="") + PROD_FT_AWS_CREDS_ACCESS_KEY_ID = String(default="") + PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY = String(default="") + PROD_FT_AWS_CREDS_ENDPOINT_URL = String(default="https://s3.us-west-2.amazonaws.com") + PROD_FT_AWS_CREDS_REGION_NAME = String(default="us-west-2") + + PROD_S3_CHECKPOINT_ACCESS_KEY_ID = String(default="") + PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY = String(default="") + PROD_S3_CHECKPOINT_ENDPOINT_URL = String(default="") + PROD_S3_CHECKPOINT_REGION_NAME = String(default="") + + PROD_TEAM_DIR_ACCESS_KEY_ID = String(default="") + PROD_TEAM_DIR_SECRET_ACCESS_KEY = String(default="") + PROD_TEAM_DIR_ENDPOINT_URL = String(default="") + PROD_TEAM_DIR_REGION_NAME = String(default="") + + PICASSO_AUTH_MODEL_REGISTRY_API_KEY = String(default="") + PICASSO_API_ENDPOINT_URL = String(default="https://meeocvslt2.execute-api.us-west-2.amazonaws.com") + + +CRED_ENVS = CredentialEnvParser() +CRED_ENVS_DICT = { + "PROD_FT_AWS_CREDS": { + "aws_access_key_id": CRED_ENVS.PROD_FT_AWS_CREDS_ACCESS_KEY_ID, + "aws_secret_access_key": CRED_ENVS.PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY, + "endpoint_url": CRED_ENVS.PROD_FT_AWS_CREDS_ENDPOINT_URL, + "region_name": CRED_ENVS.PROD_FT_AWS_CREDS_REGION_NAME, + }, + "PROD_S3_CHECKPOINT": { + "aws_access_key_id": CRED_ENVS.PROD_S3_CHECKPOINT_ACCESS_KEY_ID, + "aws_secret_access_key": CRED_ENVS.PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY, + "endpoint_url": CRED_ENVS.PROD_S3_CHECKPOINT_ENDPOINT_URL, + "region_name": CRED_ENVS.PROD_S3_CHECKPOINT_REGION_NAME, + }, + "PROD_TEAM_DIR": { + "aws_access_key_id": CRED_ENVS.PROD_TEAM_DIR_ACCESS_KEY_ID, + "aws_secret_access_key": CRED_ENVS.PROD_TEAM_DIR_SECRET_ACCESS_KEY, + "endpoint_url": CRED_ENVS.PROD_TEAM_DIR_ENDPOINT_URL, + "region_name": CRED_ENVS.PROD_TEAM_DIR_REGION_NAME, + }, +} diff --git a/cosmos_predict1/utils/env_parsers/env_parser.py b/cosmos_predict1/utils/env_parsers/env_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1579cbddb1edd8eff1209aa48e7b5cd674f71e2a --- /dev/null +++ b/cosmos_predict1/utils/env_parsers/env_parser.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import os + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.validator import JsonDict, Validator + +""" +Base class for parsing environment variables using validators. +The class iterates over its validators and retrieves values from environment variables of the same name. +Validators provide: +- default value +- typed parsing +- enforcement of mandatory values + +Additionally, the environment variables can be passed as a single base64-encoded string. + +We cannot enforce that a component isn't directly using the environment variables,
+so evaluation of params should throw error to make sure actual env var is correct. +""" + + +class EnvParser: + def __init__(self, b64_str=None): + if b64_str: + log.critical(f"b64_str recieved: {b64_str}") + self.from_b64(b64_str) + else: + self.from_env() + + def from_env(self): + validators = self.get_val_dict() + for key in validators.keys(): + val = os.getenv(key.upper()) + log.debug(f"getting env var {key.upper()}: {val}") + if val: + setattr(self, key, val) + self.check_mandatory_values() + + def from_json(self, file_name): + with open(file_name, "r") as f: + log.info(f"Reading env params from {file_name}") + dict = json.load(f) + for key, value in dict.items(): + setattr(self, key, value) + self.check_mandatory_values() + + def to_b64(self): + json_str = self.to_json() + # create bytes-like object for b64 encoder + json_str_bytes = json_str.encode() + b64_str = base64.b64encode(json_str_bytes).decode() + + print(b64_str) + return b64_str + + def from_b64(self, b64_str): + json_str = base64.b64decode(b64_str).decode() + dict = json.loads(json_str) + for key, value in dict.items(): + setattr(self, key, value) + self.check_mandatory_values() + + def check_mandatory_values(self): + for key, validator in self.get_val_dict().items(): + if getattr(self, key) is None and validator.default is None: + raise ValueError(f"Missing mandatory env var: {key}") + + @classmethod + def get_val_dict(cls): + log.debug(f"getting val dict of {cls.__name__}") + val_dict = {} + val_dict.update({key: value for key, value in cls.__dict__.items() if isinstance(value, Validator)}) + + return val_dict + + def dump_validators(self): + validators = self.get_val_dict() + for key, value in validators.items(): + log.debug(f"{key}: {value.__get__(self)}") + + def to_json(self, file_name=None): + dict = { + key.upper(): value.__get__(self) + for key, value in EnvParser.__dict__.items() + if isinstance(value, Validator) + } + json_str = json.dumps(dict, indent=4) + print(json_str) + + if file_name: + with open(file_name, "w") as f: + log.info(f"Writing env params to {file_name}") + f.write(json_str) + + return json_str + + def to_string_dict(self): + result = {} + for key, validator in self.get_val_dict().items(): + value = getattr(self, key) + if value is None: + value = validator.default + if isinstance(validator, JsonDict): + value = json.dumps(value) + else: + value = str(value) + result[key] = value + return result + + def __str__(self): + return ", ".join(f"{key}={value}" for key, value in self.__dict__.items()) diff --git a/cosmos_predict1/utils/fsdp_checkpointer.py b/cosmos_predict1/utils/fsdp_checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba4d670f8f7083d19737dc013bf59715205d266 --- /dev/null +++ b/cosmos_predict1/utils/fsdp_checkpointer.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
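A sketch of defining a project-specific parser on top of EnvParser; the MY_APP_* names are hypothetical, and it assumes the String validator (the same one used by CredentialEnvParser above) parses and stores values assigned by from_env().

```python
import os

from cosmos_predict1.utils.env_parsers.env_parser import EnvParser
from cosmos_predict1.utils.validator import String


class MyAppEnvParser(EnvParser):
    MY_APP_BUCKET = String(default="my-bucket")
    MY_APP_REGION = String(default="us-west-2")


os.environ["MY_APP_BUCKET"] = "checkpoints"
envs = MyAppEnvParser()    # from_env() reads MY_APP_* from the environment
print(envs.MY_APP_BUCKET)  # "checkpoints"
print(envs.MY_APP_REGION)  # falls back to the default "us-west-2"
```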
+ +from __future__ import annotations + +import gc +import os +import threading + +import torch +from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType + +from cosmos_predict1.utils import callback, distributed, log, misc +from cosmos_predict1.utils.config import CheckpointConfig, JobConfig +from cosmos_predict1.utils.easy_io import easy_io +from cosmos_predict1.utils.fsdp_optim_fix import scatter_full_optim_state_dict +from cosmos_predict1.utils.model import Model + + +class FSDPCheckpointer: + """The checkpointer class. Supports checkpoint saving/loading to local disk.""" + + def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): + """Constructor of the checkpointer. + + Args: + config_checkpoint (CheckpointConfig): The config object for the checkpointer. + """ + # Set the callback functions. + self.callbacks = callbacks + self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" + self.strict_resume = config_checkpoint.strict_resume + self.load_path = config_checkpoint.load_path + self.load_training_state = config_checkpoint.load_training_state + self.save_thread = None + self.config_checkpoint = config_checkpoint + + def _load_ckpt_file_during_init(self): + latest_checkpoint_file = self._read_latest_checkpoint_file() + if latest_checkpoint_file is not None: + # 1. Resume training from latest_checkpoint.txt under the same name. + checkpoint_dir = self.checkpoint_dir_local + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) + resume = True + log.critical(f"[Checkpoint] Found latest checkpoint file: {latest_checkpoint_file}") + log.critical(f"[Checkpoint] Loading from local path: {checkpoint_path}") + log.critical("[Checkpoint] Will resume full training state (model, optimizer, scheduler)") + else: + if self.load_path: + # 2. Load the module weights specified by config_checkpoint.path. + checkpoint_path = self.load_path + resume = self.load_training_state + log.critical(f"[Checkpoint] Using specified checkpoint path: {checkpoint_path}") + if resume: + log.critical("[Checkpoint] Will load complete training state (model, optimizer, scheduler)") + else: + log.critical("[Checkpoint] Will load model weights only (no optimizer/scheduler state)") + else: + # 3. Randomly initialize the model parameters and train from scratch. + checkpoint_path = None + resume = False + log.critical("[Checkpoint] No checkpoint path specified") + log.critical("[Checkpoint] Starting fresh training with random initialization") + return checkpoint_path, resume + + @misc.timer("FSDP.load_model_during_init") + def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): + if ema_id > 0: + assert is_ema, "ema_id should be used with is_ema=True" + checkpoint_path, _ = self._load_ckpt_file_during_init() + if checkpoint_path is not None: + tag = "reg" if not is_ema else "ema" + default_checkpoint_path = checkpoint_path.replace(".pt", f"_{tag}_model.pt") + if not os.path.exists(default_checkpoint_path): + default_checkpoint_path = checkpoint_path # starting from the release checkpoint + log.warning(f"is_ema={is_ema} model is not found. 
Loading from {default_checkpoint_path}") + if tag == "ema" and ema_id > 0: + _checkpoint_path = checkpoint_path.replace(".pt", f"_RANK{ema_id}.pt") + _checkpoint_path = _checkpoint_path.replace(".pt", f"_{tag}_model.pt") + if self._check_checkpoint_exists(_checkpoint_path, is_raise=False): + default_checkpoint_path = _checkpoint_path + else: + print( + f"{distributed.get_rank()}: Checkpoint not found: {_checkpoint_path} " + f"(fallback to {default_checkpoint_path})" + ) + checkpoint_path = default_checkpoint_path + self._check_checkpoint_exists(checkpoint_path) + + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + log.info("- Loading the model...") + if self.strict_resume: + log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) + else: + log.critical("\t Using non-strict model") + from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model + + log.info(non_strict_load_model(model, state_dict)) + log.info("-finish model loading") + else: + log.info(f"is_ema={is_ema} model is not found and loaded.") + + @misc.timer("FSDP.load_optim_scheduler_during_init") + def load_optim_scheduler_during_init(self, fsdp_model, optimizer, scheduler): + checkpoint_path, resume = self._load_ckpt_file_during_init() + log.critical(f"Loading optimizer and scheduler: {checkpoint_path} (resume: {resume}") + if checkpoint_path is not None: + if resume: + checkpoint_path = checkpoint_path.replace(".pt", "_optim.pt") + self._check_checkpoint_exists(checkpoint_path) + if distributed.get_rank() == 0: + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load( + checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False + ) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + log.info("- Loading the optimizer (FSDP scatter)...") + else: + state_dict = { + "optimizer": None, + "scheduler": None, + } + distributed.barrier() + sharded_optimizer_state_dict = scatter_full_optim_state_dict( # <---- FSDP + state_dict["optimizer"], + fsdp_model, + ) + log.info("- Loading the optimizer (FSDP load_state_dict)...") + log.info(optimizer.load_state_dict(sharded_optimizer_state_dict)) + log.critical("Skip loading the scheduler...") + return + log.info("- Loading the scheduler...") + scheduler.load_state_dict(state_dict["scheduler"]) + + @misc.timer("FSDP get_optim_scheduler_state") + def get_optim_scheduler_state(self, optim, fsdp_model, scheduler): + with FSDP.state_dict_type( + fsdp_model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) + scheduler_statedict = scheduler.state_dict() + return { + "optimizer": optim_statedict, + "scheduler": scheduler_statedict, + } + + def save( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + iteration: int, + async_saving: bool = True, + ) -> None: + """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. + + Args: + model (Model): The PyTorch model. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. 
+ grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + iteration (int): Current iteration number. + """ + self.callbacks.on_save_checkpoint_start(model, iteration) + + model_state_dict = model.state_dict_model() + optim_scheduler_state_dict = self.get_optim_scheduler_state(optimizer, model.model, scheduler) + torch.cuda.empty_cache() + state_dict = dict( + iteration=iteration, + ) + self.callbacks.on_save_checkpoint(model, state_dict=state_dict) + + postfix, replicate_idx, shard_idx, total_ema_num = model.get_ckpt_postfix() + if replicate_idx == 0 and shard_idx == 0: + pass # save whole; it is rank0 + elif replicate_idx < total_ema_num and shard_idx == 0: + model_state_dict["model"] = None # only save ema + optim_scheduler_state_dict = None + state_dict = None + else: + return + + checkpoint_file = f"iter_{iteration:09}{postfix}.pt" + if async_saving: + # Wait for previous saver thread to end. + if self.save_thread: + self.save_thread.join() + # Run the checkpoint saver in a separate thread. + self.save_thread = threading.Thread( + target=self._save_worker_local, + daemon=False, + args=( + model_state_dict, + optim_scheduler_state_dict, + state_dict, + checkpoint_file, + distributed.get_rank(), + ), + ) + self.save_thread.start() + log.info("checkpoint saving from an async thread") + else: + torch.cuda.empty_cache() + # Run the checkpoint saver in the current thread. + self._save_worker_local( + model_state_dict, optim_scheduler_state_dict, state_dict, checkpoint_file, distributed.get_rank() + ) + log.info("checkpoint saved within the main thread") + del model_state_dict, optim_scheduler_state_dict, state_dict + gc.collect() + torch.cuda.empty_cache() + self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) + + @misc.timer("checkpoint saving (local)") + def _save_worker_local( + self, + model_state_dict: dict[str, torch.Tensor], + optim_scheduler_state_dict: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor], + checkpoint_file: str, + rank: int = 0, + ) -> None: + """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). + + Args: + state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. + checkpoint_file (str): The file name of the model checkpoint. + rank (int): GPU device (default: 0). 
+ """ + checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) + os.makedirs(self.checkpoint_dir_local, exist_ok=True) + try: + model_state_dict, ema_model_state_dict = model_state_dict["model"], model_state_dict["ema"] + if model_state_dict is not None: + torch.save(model_state_dict, checkpoint_path.replace(".pt", "_reg_model.pt")) + if ema_model_state_dict is not None: + torch.save(ema_model_state_dict, checkpoint_path.replace(".pt", "_ema_model.pt")) + if optim_scheduler_state_dict is not None: + torch.save(optim_scheduler_state_dict, checkpoint_path.replace(".pt", "_optim.pt")) + if state_dict is not None: + torch.save(state_dict, checkpoint_path) + if rank == 0: + self._write_latest_checkpoint_file(checkpoint_file) + log.success(f"Saved checkpoint (local): {checkpoint_path}") + iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) + self.callbacks.on_save_checkpoint_success(iteration=iteration) + except Exception as e: # noqa: BLE001 + log.exception(f"Checkpoint failed to save (local): {e}") + + @misc.timer("checkpoint loading") + def load( + self, + model: Model, + optimizer: torch.optim.Optimizer | None = None, + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> int: + """Load network weights and optimizer states from a checkpoint in a single process. + + The priority of the checkpoint loading logic is: + 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. + 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. + - This is typically used for inference mode. + - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. + 3. If none of the above, randomly initialize the model parameters and train from scratch. + + Args: + model (FSDPDiffModle): The PyTorch model. + optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). + scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). + grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). + + Returns: + iteration (int): the iteration number to start/resume from. + """ + self.callbacks.on_load_checkpoint_start(model) + + del optimizer, grad_scaler + checkpoint_path, resume = self._load_ckpt_file_during_init() + iteration = 0 + if checkpoint_path is not None: + self._check_checkpoint_exists(checkpoint_path) + log.info(f"Loading checkpoint (local): {checkpoint_path}") + state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) + log.success(f"Complete loading checkpoint (local): {checkpoint_path}") + self.callbacks.on_load_checkpoint(model, state_dict=state_dict) + if resume: + iteration = state_dict["iteration"] + log.success("Done with loading the checkpoint.") + else: + log.info("Training from scratch.") + torch.cuda.empty_cache() + + self.callbacks.on_load_checkpoint_end(model) + + if scheduler is not None: + scheduler.last_epoch = iteration + log.critical(f"resume scheduler from {iteration}", rank0_only=False) + + return iteration + + def _read_latest_checkpoint_file(self) -> str | None: + """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. + + Returns: + checkpoint_file (str | None): file name of the latest saved checkpoint. 
+ """ + checkpoint_file = None + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + if os.path.isfile(latest_path): + checkpoint_file = open(latest_path).read().strip() + if checkpoint_file is None: + log.warning(f"Latest ckpt file not found: {latest_path}") + else: + log.info(f"Found latest checkpoint: {checkpoint_file}") + return checkpoint_file + + def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: + """Track the file name of the latest saved checkpoint. + + Args: + checkpoint_file (str): file name of the latest saved checkpoint. + """ + content = f"{checkpoint_file}\n" + latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") + with open(latest_path, "w") as file: + file.write(content) + + def _check_checkpoint_exists(self, checkpoint_path: str, is_raise: bool = True) -> None: + """If the file checkpoint_path does not exist, raise an error. + + Args: + checkpoint_path (str): full path to the checkpoint. + """ + if not os.path.exists(checkpoint_path): + if is_raise: + raise FileNotFoundError(f"File not found (local): {checkpoint_path}") + return False + return True + + def finalize(self) -> None: + """Finalize the checkpointer.""" + if self.save_thread: + self.save_thread.join() + + +class FSDPInferenceCheckpointer: + def __init__( + self, + ckpt_path: str, + strict_resume: bool = True, + ): + self.ckpt_path = ckpt_path + self.strict_resume = strict_resume + + @misc.timer("FSDPInferenceCheckpointer.load_model_during_init") + def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): + del ema_id + if is_ema: + log.warning("EMA model is not supported in inference mode.") + return + assert easy_io.exists(self.ckpt_path) + log.info(f"Loading from {self.ckpt_path}") + state_dict = torch.load(self.ckpt_path, map_location=lambda storage, loc: storage, weights_only=False) + if self.strict_resume: + log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) + else: + log.critical("\t Using non-strict model") + from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model + + log.info(non_strict_load_model(model, state_dict)) + log.info("-finish model loading") + + def load_optim_scheduler_during_init(self, *args, **kwargs): + """ + We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. + """ + pass + + def save(self, *args, **kwargs): + """ + We do not save anything in inference mode. The function is here to maintain the same interface to avoid errors. + """ + pass + + def load(self, *args, **kwargs): + """ + We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. + """ + return 0 diff --git a/cosmos_predict1/utils/fsdp_optim_fix.py b/cosmos_predict1/utils/fsdp_optim_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..a08aa943828d4c9c385d25873528edae0a84ec24 --- /dev/null +++ b/cosmos_predict1/utils/fsdp_optim_fix.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +# isort: skip_file + +""" +torch 2.2 has bugs in loading optimizer states for FSDP in hybrid mode +torch impl uses state.rank and dist.rank() inconsistently +The file fix the bugs. Verified it works for hybrid mode and fullly sharded mode +Please use the `scatter_full_optim_state_dict` in the code to replace the corresponding function in torch 2.2 +""" + +import copy +import warnings +from typing import Any, Dict, Iterable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._optim_utils import ( + _flatten_optim_state, + _FSDPState, + _get_fqn_to_fsdp_param_info, + _get_param_to_fqns, + _OptimStateKey, + _PosDimTensorInfo, + _shard_orig_param_state, + tree_map_only, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import _rekey_sharded_optim_state_dict + + +def _broadcast_processed_state( + fsdp_state: _FSDPState, + optim_state: Dict[str, Any], + group: Optional[dist.ProcessGroup], +) -> Dict[str, Any]: + objects: List[Any] = [None] + if fsdp_state.rank == 0: + objects[0] = tree_map_only( + torch.Tensor, + lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), + optim_state, + ) + dist.broadcast_object_list(objects, src=0, group=group) + if dist.get_rank() == 0: + return optim_state + else: + return objects[0] + + +def _broadcast_state(fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]) -> Any: + if dist.get_rank() == 0: + if not isinstance(state, torch.Tensor) or state.dim() == 0: + return state + tensor = state.to(fsdp_state.compute_device) + else: + if isinstance(state, torch.Tensor): + assert state.dim() == 0, ( + "For non-zero ranks, a tensor state should have zero dimension, " + "but got the state with shape {state.shape()}." + ) + return state + elif not isinstance(state, _PosDimTensorInfo): + return state + tensor = torch.zeros(state.shape, dtype=state.dtype, device=fsdp_state.compute_device) + dist.broadcast(tensor, src=0, group=group) + return tensor + + +def _flatten_optim_state_dict( + optim_state_dict: Dict[str, Any], + model: nn.Module, + use_orig_params: bool = False, + optim: Optional[torch.optim.Optimizer] = None, + rank0_only: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + Flattens the full optimizer state dict, still keying by unflattened parameter + names. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP know how to aggregate them. 
However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- it is managed by other parallelism and FSDP does not + know ho to handle/aggregate them. + + Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to + flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require + all the states even if the corresponding parameters are empty. To this end, + ``optim`` will be used to to get the initial state of the empty parameters. + ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or + NamedOptimizer. + + Returns: + Dict[str, Any]: The flattened optimizer state dict. + """ + SimpleProfiler.reset() + + unflat_osd = optim_state_dict + if "state" not in unflat_osd and not rank0_only: + raise ValueError('`optim_state_dict` must have the keys "state"' "to be a valid optimizer state dict") + param_to_fqns = _get_param_to_fqns(model) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state + + # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. + if rank0_only: + unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) + + # Construct the "state" part + flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {} + unflat_osd_state = unflat_osd["state"] + all_state_keys = set(unflat_osd_state.keys()) + + for param, fqns in param_to_fqns.items(): + fqn = fqns[0] + if fqn not in unflat_osd_state: + continue + all_state_keys.difference_update(fqns) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name in unflat_osd_state[fqn].keys(): + unflat_osd_state[fqn][state_name] = _broadcast_state( + fsdp_state, unflat_osd_state[fqn][state_name], group=group + ) + fqn = fqns[0] + if fqn in fqn_to_fsdp_param_info: + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if use_orig_params: + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + flat_state = _shard_orig_param_state( + fsdp_param_info, + fqn, + unflat_osd_state[fqn], + ) + else: + flat_state = _flatten_optim_state( + fsdp_param_info, + unflat_osd_state, + fqns, + ) + key = _OptimStateKey(tuple(fqns), True) + # Only include non-empty states since as expected by + # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer + # or NamedOptimizer. + if flat_state: + flat_osd_state[key] = flat_state + elif use_orig_params: + assert len(fqns) == 1, f"use_orig_params is True but there are multiple FQNs, {fqns}." + if optim is not None: # NamedOptimizer or KeyedOptimizer case. + state = optim.state.get(param, None) # type: ignore[call-overload] + if state is not None: + flat_osd_state[key] = copy.deepcopy(state) + else: + warnings.warn(f"optim_state[{key}] is not on rank{fsdp_state.rank}.") + + else: + raise RuntimeError(f"The state of {key} is empty. This should happen when " "use_orig_params=True.") + else: # do not flatten non-FSDP parameters' states + assert len(fqns) == 1 + key = _OptimStateKey(tuple(fqns), False) + flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name, param_state in list(unflat_osd_state[fqn].items()): + if fsdp_state.rank > 0: + # Deference the tensor so that PyTorch can collect the memory. + del unflat_osd_state[fqn][state_name] + else: + # Move the tensor in the original osd back to CPU to make the + # original osd unaffected. 
+ unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][state_name].cpu() + + # Handle user-defined state, states that are not associated with parameters. + for key in all_state_keys: + user_state = unflat_osd_state[key] + if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: + user_state = _broadcast_state(fsdp_state, user_state, group=group) + flat_osd_state[key] = copy.copy(user_state) + + SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") + # Construct the "param_groups" part -- copy as is since it will be + # rekeyed later according to the target rank's optimizer + # Only copy param_groups if it exists in unflat_osd + if "param_groups" in unflat_osd: + flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) + return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} + else: + return {"state": flat_osd_state} + + +def _optim_state_dict_to_load_impl( + optim_state_dict: Dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + full_state_dict: bool = True, + rank0_only: bool = False, + is_named_optimizer: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Any]: + """ + The internal API that is used by all the load optim_state_dict implementations. + Given model, optim, and the saved optim_state_dict, this API adds the FSDP + internal information and internal sharding to the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + assert optim_input is None and not rank0_only + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[0]._use_orig_params + assert all( + use_orig_params == m._use_orig_params for m in FullyShardedDataParallel.fsdp_modules(model) + ), "Not all FSDP modules have the same _use_orig_params value" + + if rank0_only and dist.get_rank(group) > 0: + optim_state_dict = {} + sharded_osd = _flatten_optim_state_dict( + optim_state_dict, + model=model, + use_orig_params=use_orig_params, + optim=(optim if is_named_optimizer else None), + rank0_only=rank0_only, + group=group, + ) + return _rekey_sharded_optim_state_dict( + sharded_osd, + model=model, + optim=optim, + optim_input=optim_input, + using_optim_input=using_optim_input, + is_named_optimizer=is_named_optimizer, + ) + + +def scatter_full_optim_state_dict( + full_optim_state_dict: Optional[Dict[str, Any]], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + List[Dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + group: Optional[Any] = None, +) -> Dict[str, Any]: + """ + Scatters the full optimizer state dict from rank 0 to all other ranks, + returning the sharded optimizer state dict on each rank. The return + value is the same as :meth:`shard_full_optim_state_dict`, and on rank + 0, the first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 + >>> # Define new model with possibly different world size + >>> new_model, new_optim, new_group = ... 
+ >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state + dict corresponding to the unflattened parameters and holding + the full non-sharded optimizer state if on rank 0; the argument + is ignored on nonzero ranks. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + group (dist.ProcessGroup): Model's process group or ``None`` if + using the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict("scatter_full_optim_state_dict", "optim_state_dict_to_load") + return _optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + rank0_only=True, + is_named_optimizer=False, + group=group, + ) diff --git a/cosmos_predict1/utils/fused_adam.py b/cosmos_predict1/utils/fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..192e29552f8fcaec5b50f325d35c32b6807948f1 --- /dev/null +++ b/cosmos_predict1/utils/fused_adam.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
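# --- Illustrative sketch (hedged; not part of the original diff) -------------
# How the patched scatter_full_optim_state_dict from
# cosmos_predict1/utils/fsdp_optim_fix.py is meant to be used when resuming an
# FSDP run, mirroring FSDPCheckpointer.load_optim_scheduler_during_init above:
# only rank 0 needs the full optimizer state dict in CPU memory; the other
# ranks pass None and receive their shard. The model, optimizer, and
# checkpoint path below are placeholders.
def _example_resume_fsdp_optimizer(fsdp_model, optimizer, optim_ckpt_path):
    import torch
    import torch.distributed as dist

    from cosmos_predict1.utils.fsdp_optim_fix import scatter_full_optim_state_dict

    if dist.get_rank() == 0:
        # Rank 0 loads the full (unsharded) state written to "*_optim.pt",
        # i.e. {"optimizer": ..., "scheduler": ...}.
        full_osd = torch.load(optim_ckpt_path, map_location="cpu", weights_only=False)["optimizer"]
    else:
        full_osd = None  # non-zero ranks receive their shard via broadcast/scatter
    dist.barrier()
    sharded_osd = scatter_full_optim_state_dict(full_osd, fsdp_model)
    optimizer.load_state_dict(sharded_osd)
# -----------------------------------------------------------------------------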
+ +import torch +from apex.multi_tensor_apply import multi_tensor_applier + +from cosmos_predict1.utils import distributed, log + + +class FusedAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + Currently GPU-only. Requires Apex to be installed via + ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. + + This version of fused Adam implements 2 fusions. + + * Fusion of the Adam update's elementwise operations + * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters + into one or a few kernel launches. + + :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adam_w_mode=False``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + ... + opt.step() + + :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp, + you may choose any ``opt_level``:: + + opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) + model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") + ... + opt.step() + + In general, ``opt_level="O1"`` is recommended. + + + .. warning:: + A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``. + These additional arguments are now deprecated and unnecessary. + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + adam_w_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + capturable (bool, optional): whether to use the version of the optimizer + that can be used with CUDA Graphs. (default: False) + master_weights (bool, optional): whether to maintain FP32 master weights + in the optimizer with FP16 mixed precision training, currently can + only be used with capturable set to True. (default: False) + + .. _Adam - A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adam_w_mode=True, + weight_decay=0.0, + amsgrad=False, + capturable=False, + master_weights=False, + ): + if amsgrad: + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") + if master_weights and not capturable: + raise RuntimeError("Master weights is currently only supported with the capturable version.") + # If the optimizer is capturable then LR should be a tensor (on GPU) + log.warning(f"FusedAdam master_weights: {master_weights} capturable: {capturable}") + lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr + defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) + super(FusedAdam, self).__init__(params, defaults) + self.adam_w_mode = 1 if adam_w_mode else 0 + + self.capturable = capturable + self.master_weights = master_weights + + self.param_groups_master = None + + if capturable: + for idx, group in enumerate(self.param_groups): + if len(group["params"]) == 0: + continue + device = group["params"][0].device + for item in ["lr"]: + if isinstance(group[item], float): + group[item] = torch.tensor(group[item], dtype=torch.float32) + self.param_groups[idx][item] = group[item].to(device=device) + + self._step_supports_amp_scaling = True + + if multi_tensor_applier.available: + import amp_C + + # Skip buffer + self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + self.multi_tensor_adam = amp_C.multi_tensor_adam + self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable + self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master + else: + raise RuntimeError("apex.optimizers.FusedAdam requires cuda extensions") + + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + "FusedAdam has been updated. " + "Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." 
+ ) + loss = None + if closure is not None: + loss = closure() + + if self.param_groups_master is None: + # Create full precision master weights + self.param_groups_master = [] + for i, pg in enumerate(self.param_groups): + param_list = pg["params"] + self.param_groups_master.append( + { + "params": [p.clone().detach().float() if self.master_weights else None for p in param_list], + } + ) + + for group, group_master in zip(self.param_groups, self.param_groups_master): + if len(group["params"]) == 0: + continue + device = group["params"][0].device + bias_correction = 1 if "bias_correction" in group and group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + if self.capturable: + group["step"] = ( + group["step"].to(device=device) + if isinstance(group["step"], torch.Tensor) + else torch.tensor(group["step"], dtype=torch.int32, device=device) + ) + group["step"] += (self._dummy_overflow_buf != 1).to(torch.int) + else: + group["step"] += 1 + else: + group["step"] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device) + + if self.capturable: + group["lr"] = ( + group["lr"].to(device=device) + if isinstance(group["lr"], torch.Tensor) + else torch.tensor(group["lr"], dtype=torch.float32, device=device) + ) + + # create lists for multi-tensor apply + g_16, p_16, m_16, v_16 = [], [], [], [] + g_bf, p_bf, m_bf, v_bf = [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + p_16_master = [] + p_32_master = [] + bf16_master = [] + + for p, p_master in zip(group["params"], group_master["params"]): + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + "FusedAdam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data).float() + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data).float() + + if p.dtype == torch.float16: + if self.master_weights: + p_16_master.append(p_master.data) + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) + elif p.dtype == torch.bfloat16: + if self.master_weights: + bf16_master.append(p_master.data) + g_bf.append(p.grad) + p_bf.append(p) + m_bf.append(state["exp_avg"]) + v_bf.append(state["exp_avg_sq"]) + elif p.dtype == torch.float32: + if self.master_weights: + p_32_master.append(p_master.data) + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) + else: + raise RuntimeError("FusedAdam only support fp16 and fp32.") + + # If the optimizer is capturable, then if there's a grad scaler it works + # on the GPU + a different multi_tensor_applier should be called + if self.capturable: + # overflow check of gradients + found_inf = ( + grad_scaler._check_inf_per_device(self)[device] + if grad_scaler is not None + else torch.zeros((1,), device=device) + ) + self._dummy_overflow_buf.copy_(found_inf) + + # get unscale scale factor + scale, inv_scale = None, None + if grad_scaler: + scale = grad_scaler._get_scale_async() + inv_scale = scale.double().reciprocal().float() + else: + scale = torch.ones((1,), device=device, dtype=torch.float32) + inv_scale = torch.ones((1,), device=device, 
dtype=torch.float32) + + if len(g_16) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16, p_16_master] if self.master_weights else [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + + if len(g_bf) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf, bf16_master] if self.master_weights else [g_bf, p_bf, m_bf, v_bf], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + + if len(g_32) > 0: + multi_tensor_applier( + ( + self.multi_tensor_adam_capturable_master + if self.master_weights + else self.multi_tensor_adam_capturable + ), + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32, p_32_master] if self.master_weights else [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + inv_scale, + ) + else: + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + if len(g_bf) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_bf, p_bf, m_bf, v_bf], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + return loss + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + for group in self.param_groups: + if self.capturable: + group["lr"] = ( + group["lr"].cuda() + if isinstance(group["lr"], torch.Tensor) + else torch.tensor(group["lr"], dtype=torch.float32).cuda() + ) + + if "step" in group: + if self.capturable: + if distributed.get_rank() == 0: + step = ( + group["step"].cuda() + if isinstance(group["step"], torch.Tensor) + else torch.tensor([group["step"]], dtype=torch.int32).cuda() + ) + else: + step = torch.zeros(1, dtype=torch.int32).cuda() + # make it compatible with FSDP optimizer + distributed.broadcast(step, 0) + group["step"] = step + elif isinstance(group["step"], torch.Tensor): + group["step"] = group["step"].item() + for p in group["params"]: + state = self.state[p] + if "exp_avg" in state: + state["exp_avg"] = state["exp_avg"].float() + state["exp_avg_sq"] = state["exp_avg_sq"].float() diff --git a/cosmos_predict1/utils/io.py b/cosmos_predict1/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..c877aa41fd6b90638281f048bac23fc8214b84be --- /dev/null +++ b/cosmos_predict1/utils/io.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from io import BytesIO +from typing import Dict, List + +import imageio +import numpy as np + + +def read_prompts_from_file(prompt_file: str) -> List[Dict[str, str]]: + """Read prompts from a JSONL file where each line is a dict with 'prompt' key and optionally 'visual_input' key. + + Args: + prompt_file (str): Path to JSONL file containing prompts + + Returns: + List[Dict[str, str]]: List of prompt dictionaries + """ + prompts = [] + with open(prompt_file, "r") as f: + for line in f: + prompt_dict = json.loads(line.strip()) + prompts.append(prompt_dict) + return prompts + + +def save_video(video, fps, H, W, video_save_quality, video_save_path): + """Save video frames to file. + + Args: + grid (np.ndarray): Video frames array [T,H,W,C] + fps (int): Frames per second + H (int): Frame height + W (int): Frame width + video_save_quality (int): Video encoding quality (0-10) + video_save_path (str): Output video file path + """ + kwargs = { + "fps": fps, + "quality": video_save_quality, + "macro_block_size": 1, + "ffmpeg_params": ["-s", f"{W}x{H}"], + "output_params": ["-f", "mp4"], + } + imageio.mimsave(video_save_path, video, "mp4", **kwargs) + + +def load_from_fileobj(filepath: str, format: str = "mp4", mode: str = "rgb", **kwargs): + """ + Load video from a file-like object using imageio with specified format and color mode. + + Parameters: + file (IO[bytes]): A file-like object containing video data. + format (str): Format of the video file (default 'mp4'). + mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). + + Returns: + tuple: A tuple containing an array of video frames and metadata about the video. + """ + with open(filepath, "rb") as f: + value = f.read() + with BytesIO(value) as f: + f.seek(0) + video_reader = imageio.get_reader(f, format, **kwargs) + + video_frames = [] + for frame in video_reader: + if mode == "gray": + import cv2 # Convert frame to grayscale if mode is gray + + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = np.expand_dims(frame, axis=2) # Keep frame dimensions consistent + video_frames.append(frame) + + return np.array(video_frames), video_reader.get_meta_data() diff --git a/cosmos_predict1/utils/lazy_config/__init__.py b/cosmos_predict1/utils/lazy_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3df830db623db39690c68ae09fa7e576cea5d0c --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
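# --- Illustrative sketch (hedged; not part of the original diff) -------------
# End-to-end use of the helpers defined in cosmos_predict1/utils/io.py above:
# read a JSONL prompt file, decode an optional conditioning video, and write
# frames back out. "prompts.jsonl" and "output.mp4" are placeholder paths.
def _example_io_roundtrip():
    from cosmos_predict1.utils.io import load_from_fileobj, read_prompts_from_file, save_video

    prompts = read_prompts_from_file("prompts.jsonl")  # [{"prompt": ..., "visual_input": ...}, ...]
    for entry in prompts:
        if "visual_input" not in entry:
            continue
        frames, meta = load_from_fileobj(entry["visual_input"], format="mp4", mode="rgb")
        _, H, W, _ = frames.shape  # frames is a [T, H, W, C] array
        save_video(
            video=list(frames),
            fps=int(meta.get("fps", 24)),  # imageio's ffmpeg reader typically reports "fps"
            H=H,
            W=W,
            video_save_quality=5,
            video_save_path="output.mp4",
        )
# -----------------------------------------------------------------------------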
+import os + +from omegaconf import DictConfig, OmegaConf + +from cosmos_predict1.utils.lazy_config.instantiate import instantiate +from cosmos_predict1.utils.lazy_config.lazy import LazyCall, LazyConfig +from cosmos_predict1.utils.lazy_config.omegaconf_patch import to_object + +OmegaConf.to_object = to_object + +PLACEHOLDER = None +LazyDict = DictConfig + +__all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"] + + +DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py + + +def fixup_module_metadata(module_name, namespace, keys=None): + """ + Fix the __qualname__ of module members to be their exported api name, so + when they are referenced in docs, sphinx can find them. Reference: + https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 + """ + if not DOC_BUILDING: + return + seen_ids = set() + + def fix_one(qualname, name, obj): + # avoid infinite recursion (relevant when using + # typing.Generic, for example) + if id(obj) in seen_ids: + return + seen_ids.add(id(obj)) + + mod = getattr(obj, "__module__", None) + if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): + obj.__module__ = module_name + # Modules, unlike everything else in Python, put fully-qualitied + # names into their __name__ attribute. We check for "." to avoid + # rewriting these. + if hasattr(obj, "__name__") and "." not in obj.__name__: + obj.__name__ = name + obj.__qualname__ = qualname + if isinstance(obj, type): + for attr_name, attr_value in obj.__dict__.items(): + fix_one(objname + "." + attr_name, attr_name, attr_value) + + if keys is None: + keys = namespace.keys() + for objname in keys: + if not objname.startswith("_"): + obj = namespace[objname] + fix_one(objname, objname, obj) + + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/cosmos_predict1/utils/lazy_config/file_io.py b/cosmos_predict1/utils/lazy_config/file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..d9caf0081976dd08ab6ea1c04ad53304bc51d05d --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/file_io.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
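# --- Illustrative sketch (hedged; not part of the original diff) -------------
# Effect of fixup_module_metadata in lazy_config/__init__.py above: when the
# _DOC_BUILDING environment variable is set before the package is imported,
# re-exported symbols report the public package path so Sphinx can resolve
# cross-references. For example (assuming a fresh interpreter):
#
#   import os
#   os.environ["_DOC_BUILDING"] = "1"          # must be set before the import
#   from cosmos_predict1.utils.lazy_config import LazyCall
#   print(LazyCall.__module__)                  # "cosmos_predict1.utils.lazy_config"
#                                               # instead of "...lazy_config.lazy"
# -----------------------------------------------------------------------------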
+ +from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler +from iopath.common.file_io import PathManager as PathManagerBase + +__all__ = ["PathManager", "PathHandler"] + + +PathManager = PathManagerBase() +PathManager.register_handler(HTTPURLHandler()) +PathManager.register_handler(OneDrivePathHandler()) diff --git a/cosmos_predict1/utils/lazy_config/instantiate.py b/cosmos_predict1/utils/lazy_config/instantiate.py new file mode 100644 index 0000000000000000000000000000000000000000..3c87b7a555b2292468360b015b8100d09cf5cc59 --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/instantiate.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc as abc +import dataclasses +import logging +from typing import Any + +import attrs + +from cosmos_predict1.utils.lazy_config.registry import _convert_target_to_string, locate + +__all__ = ["dump_dataclass", "instantiate"] + + +def is_dataclass_or_attrs(target): + return dataclasses.is_dataclass(target) or attrs.has(target) + + +def dump_dataclass(obj: Any): + """ + Dump a dataclass recursively into a dict that can be later instantiated. + + Args: + obj: a dataclass object + + Returns: + dict + """ + assert dataclasses.is_dataclass(obj) and not isinstance( + obj, type + ), "dump_dataclass() requires an instance of a dataclass." + ret = {"_target_": _convert_target_to_string(type(obj))} + for f in dataclasses.fields(obj): + v = getattr(obj, f.name) + if dataclasses.is_dataclass(v): + v = dump_dataclass(v) + if isinstance(v, (list, tuple)): + v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v] + ret[f.name] = v + return ret + + +def instantiate(cfg, *args, **kwargs): + """ + Recursively instantiate objects defined in dictionaries by + "_target_" and arguments. + + Args: + cfg: a dict-like object with "_target_" that defines the caller, and + other keys that define the arguments + args: Optional positional parameters pass-through. + kwargs: Optional named parameters pass-through. + + Returns: + object instantiated by cfg + """ + from omegaconf import DictConfig, ListConfig, OmegaConf + + if isinstance(cfg, ListConfig): + lst = [instantiate(x) for x in cfg] + return ListConfig(lst, flags={"allow_objects": True}) + if isinstance(cfg, list): + # Specialize for list, because many classes take + # list[objects] as arguments, such as ResNet, DatasetMapper + return [instantiate(x) for x in cfg] + + # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config), + # instantiate it to the actual dataclass. 
+ if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type): + return OmegaConf.to_object(cfg) + + if isinstance(cfg, abc.Mapping) and "_target_" in cfg: + # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all, + # but faster: https://github.com/facebookresearch/hydra/issues/1200 + cfg = {k: instantiate(v) for k, v in cfg.items()} + cls = cfg.pop("_target_") + cls = instantiate(cls) + + if isinstance(cls, str): + cls_name = cls + cls = locate(cls_name) + assert cls is not None, cls_name + else: + try: + cls_name = cls.__module__ + "." + cls.__qualname__ + except Exception: + # target could be anything, so the above could fail + cls_name = str(cls) + assert callable(cls), f"_target_ {cls} does not define a callable object" + try: + # override config with kwargs + instantiate_kwargs = {} + instantiate_kwargs.update(cfg) + instantiate_kwargs.update(kwargs) + return cls(*args, **instantiate_kwargs) + except TypeError: + logger = logging.getLogger(__name__) + logger.error(f"Error when instantiating {cls_name}!") + raise + return cfg # return as-is if don't know what to do diff --git a/cosmos_predict1/utils/lazy_config/lazy.py b/cosmos_predict1/utils/lazy_config/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..24063a35533ffcdba4b435dc388d87535ad4d330 --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/lazy.py @@ -0,0 +1,430 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import ast +import builtins +import collections.abc as abc +import importlib +import inspect +import logging +import os +import pickle +import uuid +from collections import OrderedDict +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import is_dataclass +from typing import Any, Dict, List, Tuple, Union + +import attrs +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf + +from cosmos_predict1.utils.lazy_config.file_io import PathManager +from cosmos_predict1.utils.lazy_config.registry import _convert_target_to_string + +try: + import dill as dill_pickle +except ImportError: + dill_pickle = None +try: + import cloudpickle +except ImportError: + cloudpickle = None + +__all__ = ["LazyCall", "LazyConfig"] + + +def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]: + return OrderedDict(sorted(d.items(), key=lambda x: x[0])) + + +def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode: + return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) + + +def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]: + if isinstance(obj, dict): + return sort_dict({k: sort_recursive(v) for k, v in obj.items()}) + elif isinstance(obj, list): + return [sort_recursive(item) for item in obj] + return obj + + +yaml.add_representer(OrderedDict, dict_representer) + + +def get_default_params(cls_or_func): + if callable(cls_or_func): + # inspect signature for function + signature = inspect.signature(cls_or_func) + else: + # inspect signature for class + signature = inspect.signature(cls_or_func.__init__) + params = signature.parameters + default_params = { + name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty + } + return default_params + + +class LazyCall: + """ + Wrap a callable so that when it's called, the call will not be executed, + but returns a dict that describes the call. + + LazyCall object has to be called with only keyword arguments. Positional + arguments are not yet supported. + + Examples: + :: + from detectron2.config import instantiate, LazyCall + + layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32) + layer_cfg.out_channels = 64 # can edit it afterwards + layer = instantiate(layer_cfg) + """ + + def __init__(self, target): + if not (callable(target) or isinstance(target, (str, abc.Mapping))): + raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}") + self._target = target + + def __call__(self, **kwargs): + if is_dataclass(self._target) or attrs.has(self._target): + # omegaconf object cannot hold dataclass type + # https://github.com/omry/omegaconf/issues/784 + target = _convert_target_to_string(self._target) + else: + target = self._target + kwargs["_target_"] = target + + _final_params = get_default_params(self._target) + _final_params.update(kwargs) + + return DictConfig(content=_final_params, flags={"allow_objects": True}) + + +def _visit_dict_config(cfg, func): + """ + Apply func recursively to all DictConfig in cfg. 
+ """ + if isinstance(cfg, DictConfig): + func(cfg) + for v in cfg.values(): + _visit_dict_config(v, func) + elif isinstance(cfg, ListConfig): + for v in cfg: + _visit_dict_config(v, func) + + +def _validate_py_syntax(filename): + # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py + with PathManager.open(filename, "r") as f: + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError(f"Config file {filename} has syntax error!") from e + + +def _cast_to_config(obj): + # if given a dict, return DictConfig instead + if isinstance(obj, dict): + return DictConfig(obj, flags={"allow_objects": True}) + return obj + + +_CFG_PACKAGE_NAME = "detectron2._cfg_loader" +""" +A namespace to put all imported config into. +""" + + +def _random_package_name(filename): + # generate a random package name when loading config files + return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename) + + +@contextmanager +def _patch_import(): + """ + Enhance relative import statements in config files, so that they: + 1. locate files purely based on relative location, regardless of packages. + e.g. you can import file without having __init__ + 2. do not cache modules globally; modifications of module states has no side effect + 3. support other storage system through PathManager, so config files can be in the cloud + 4. imported dict are turned into omegaconf.DictConfig automatically + """ + old_import = builtins.__import__ + + def find_relative_file(original_file, relative_import_path, level): + # NOTE: "from . import x" is not handled. Because then it's unclear + # if such import should produce `x` as a python module or DictConfig. + # This can be discussed further if needed. + relative_import_err = """ +Relative import of directories is not allowed within config files. +Within a config file, relative import can only import other config files. +""".replace( + "\n", " " + ) + if not len(relative_import_path): + raise ImportError(relative_import_err) + + cur_file = os.path.dirname(original_file) + for _ in range(level - 1): + cur_file = os.path.dirname(cur_file) + cur_name = relative_import_path.lstrip(".") + for part in cur_name.split("."): + cur_file = os.path.join(cur_file, part) + if not cur_file.endswith(".py"): + cur_file += ".py" + if not PathManager.isfile(cur_file): + cur_file_no_suffix = cur_file[: -len(".py")] + if PathManager.isdir(cur_file_no_suffix): + raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err) + else: + raise ImportError( + f"Cannot import name {relative_import_path} from " f"{original_file}: {cur_file} does not exist." 
+ ) + return cur_file + + def new_import(name, globals=None, locals=None, fromlist=(), level=0): + if ( + # Only deal with relative imports inside config files + level != 0 + and globals is not None + and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME) + ): + cur_file = find_relative_file(globals["__file__"], name, level) + _validate_py_syntax(cur_file) + spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file) + module = importlib.util.module_from_spec(spec) + module.__file__ = cur_file + with PathManager.open(cur_file) as f: + content = f.read() + exec(compile(content, cur_file, "exec"), module.__dict__) + for name in fromlist: # turn imported dict into DictConfig automatically + val = _cast_to_config(module.__dict__[name]) + module.__dict__[name] = val + return module + return old_import(name, globals, locals, fromlist=fromlist, level=level) + + builtins.__import__ = new_import + yield new_import + builtins.__import__ = old_import + + +class LazyConfig: + """ + Provide methods to save, load, and overrides an omegaconf config object + which may contain definition of lazily-constructed objects. + """ + + @staticmethod + def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Similar to :meth:`load()`, but load path relative to the caller's + source file. + + This has the same functionality as a relative import, except that this method + accepts filename as a string, so more characters are allowed in the filename. + """ + caller_frame = inspect.stack()[1] + caller_fname = caller_frame[0].f_code.co_filename + assert caller_fname != "", "load_rel Unable to find caller" + caller_dir = os.path.dirname(caller_fname) + filename = os.path.join(caller_dir, filename) + return LazyConfig.load(filename, keys) + + @staticmethod + def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Load a config file. + + Args: + filename: absolute path or relative path w.r.t. the current working directory + keys: keys to load and return. If not given, return all keys + (whose values are config objects) in a dict. + """ + has_keys = keys is not None + filename = filename.replace("/./", "/") # redundant + if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: + raise ValueError(f"Config file {filename} has to be a python or yaml file.") + if filename.endswith(".py"): + _validate_py_syntax(filename) + + with _patch_import(): + # Record the filename + module_namespace = { + "__file__": filename, + "__package__": _random_package_name(filename), + } + with PathManager.open(filename) as f: + content = f.read() + # Compile first with filename to: + # 1. make filename appears in stacktrace + # 2. 
make load_rel able to find its parent's (possibly remote) location + exec(compile(content, filename, "exec"), module_namespace) + + ret = module_namespace + else: + with PathManager.open(filename) as f: + obj = yaml.unsafe_load(f) + ret = OmegaConf.create(obj, flags={"allow_objects": True}) + + if has_keys: + if isinstance(keys, str): + return _cast_to_config(ret[keys]) + else: + return tuple(_cast_to_config(ret[a]) for a in keys) + else: + if filename.endswith(".py"): + # when not specified, only load those that are config objects + ret = DictConfig( + { + name: _cast_to_config(value) + for name, value in ret.items() + if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_") + }, + flags={"allow_objects": True}, + ) + return ret + + @staticmethod + def save_pkl(cfg, filename: str) -> str: + """ + Saves a Config object to a file using pickle serialization. This method is typically used + when the configuration object contains complex objects, such as lambdas, that are not supported by + simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration + object before serialization to ensure that the original object remains unmodified. + + Args: + cfg: A Config object to be serialized and saved. + filename: The path and name of the file where the configuration should be saved. The function + assumes the file extension indicates a pickle format (e.g., .pkl). + + Returns: + str: The filename to which the configuration was saved. This can be used to verify the file location + or log the outcome. + + Notes: + - The function logs a warning if the configuration is successfully saved using pickle. + - If saving fails, an error is logged with the exception details. + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(cfg, f) + logger.warning(f"Config is saved using pickle at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead") + if dill_pickle: + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(dill_pickle.dumps(cfg, recurse=True), f) + logger.warning(f"Config is saved using dill at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}.") + if cloudpickle: + try: + with PathManager.open(filename, "wb") as f: + pickle.dump(cloudpickle.dumps(cfg), f) + logger.warning(f"Config is saved using cloudpickle at {filename}.") + except Exception as e: + logger.error(f"Failed to save config to {filename}: {e}.") + else: + logger.error("cloudpickle is not available. Cannot save the config.") + raise e + + return filename + + @staticmethod + def save_yaml(cfg, filename: str) -> str: + """ + Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization. + + Args: + cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types. + filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'. 
+ + Returns: + str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome. + + Notes: + - The function logs a warning if the configuration is successfully saved using YAML. + - If saving fails, an error is logged with the exception details. + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + + # Define a function to check if an item is serializable to YAML + def is_serializable(item): + try: + OmegaConf.to_yaml(item) + return True + except Exception as e: + return False + + # Function to convert unserializable items to strings + def serialize_config(config): + if isinstance(config, DictConfig): + for key, value in config.items(): + if isinstance(value, (DictConfig, ListConfig)): + try: + if "_target_" in value: + default_params = get_default_params(value["_target_"]) + for default_key, default_v in default_params.items(): + if default_key not in value: + value[default_key] = default_v + except Exception as e: + logger.error(f"Failed to add default argument values: {e}") + + serialize_config(value) + else: + if not is_serializable(value) and value is not None: + config[key] = str(value) + elif isinstance(config, ListConfig): + for i, item in enumerate(config): + if isinstance(item, (DictConfig, ListConfig)): + serialize_config(item) + else: + if not is_serializable(item) and item is not None: + config[i] = str(item) + else: + raise NotImplementedError("Input config must be a DictConfig or ListConfig.") + return config + + # Convert Config object to a DictConfig object. + config_dict = attrs.asdict(cfg) + config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) + + # Serialize the DictConfig object by converting non-serializable objects to strings. + config_omegaconf = serialize_config(config_omegaconf) + + config_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True) + sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict) + with open(filename, "w") as f: + yaml.dump(sorted_config, f, default_flow_style=False) + logger.warning(f"Config is saved using omegaconf at {filename}.") + return filename diff --git a/cosmos_predict1/utils/lazy_config/omegaconf_patch.py b/cosmos_predict1/utils/lazy_config/omegaconf_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..39dca42a0a71383de919b750cedf2606faae206d --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/omegaconf_patch.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
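For orientation, here is a minimal usage sketch of the LazyConfig helpers defined above; the config path and key names are hypothetical, and save_yaml is omitted because it expects an attrs-based Config object rather than the DictConfig returned by load.

from cosmos_predict1.utils.lazy_config import LazyConfig

# Load every top-level config object (DictConfig/ListConfig/dict whose name does not
# start with "_") from a hypothetical Python config file.
cfg = LazyConfig.load("cosmos_predict1/configs/my_experiment.py")

# Passing keys narrows the result: a single string returns one object, a tuple returns a tuple.
model_cfg, optimizer_cfg = LazyConfig.load(
    "cosmos_predict1/configs/my_experiment.py", keys=("model", "optimizer")
)

# save_pkl round-trips arbitrary objects; if plain pickle fails (e.g. on lambdas),
# dill or cloudpickle is tried when installed.
LazyConfig.save_pkl(cfg, "/tmp/my_experiment_config.pkl")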
+ +from typing import Any, Dict, List, Union + +from omegaconf import OmegaConf +from omegaconf.base import DictKeyType, SCMode +from omegaconf.dictconfig import DictConfig # pragma: no cover + + +def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: + """ + Converts an OmegaConf configuration object to a native Python container (dict or list), unless + the configuration is specifically created by LazyCall, in which case the original configuration + is returned directly. + + This function serves as a modification of the original `to_object` method from OmegaConf, + preventing DictConfig objects created by LazyCall from being automatically converted to Python + dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended + structure and behavior. + + Differences from OmegaConf's original `to_object`: + - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall. + + Reference: + - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595 + + Args: + cfg (Any): The OmegaConf configuration object to convert. + + Returns: + Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if + `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`. + + Examples: + >>> cfg = DictConfig({"key": "value", "_target_": "Model"}) + >>> to_object(cfg) + DictConfig({"key": "value", "_target_": "Model"}) + + >>> cfg = DictConfig({"list": [1, 2, 3]}) + >>> to_object(cfg) + {'list': [1, 2, 3]} + """ + if isinstance(cfg, DictConfig) and "_target_" in cfg.keys(): + return cfg + + return OmegaConf.to_container( + cfg=cfg, + resolve=True, + throw_on_missing=True, + enum_to_str=False, + structured_config_mode=SCMode.INSTANTIATE, + ) diff --git a/cosmos_predict1/utils/lazy_config/registry.py b/cosmos_predict1/utils/lazy_config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7c09eb428a97927d5f0407e2328a3f43afbf38fc --- /dev/null +++ b/cosmos_predict1/utils/lazy_config/registry.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pydoc +from typing import Any + +""" +`locate` provide ways to map a string (typically found +in config files) to callable objects. +""" + +__all__ = ["locate"] + + +def _convert_target_to_string(t: Any) -> str: + """ + Inverse of ``locate()``. + + Args: + t: any object with ``__module__`` and ``__qualname__`` + """ + module, qualname = t.__module__, t.__qualname__ + + # Compress the path to this object, e.g. ``module.submodule._impl.class`` + # may become ``module.submodule.class``, if the later also resolves to the same + # object. This simplifies the string, and also is less affected by moving the + # class implementation. 
+ module_parts = module.split(".") + for k in range(1, len(module_parts)): + prefix = ".".join(module_parts[:k]) + candidate = f"{prefix}.{qualname}" + try: + if locate(candidate) is t: + return candidate + except ImportError: + pass + return f"{module}.{qualname}" + + +def locate(name: str) -> Any: + """ + Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, + such as "module.submodule.class_name". + + Raise Exception if it cannot be found. + """ + obj = pydoc.locate(name) + + # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly + # by pydoc.locate. Try a private function from hydra. + if obj is None: + try: + # from hydra.utils import get_method - will print many errors + from hydra.utils import _locate + except ImportError as e: + raise ImportError(f"Cannot dynamically locate object {name}!") from e + else: + obj = _locate(name) # it raises if fails + + return obj diff --git a/cosmos_predict1/utils/log.py b/cosmos_predict1/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..45f98624193c5551c6c390dd0110d1440a610133 --- /dev/null +++ b/cosmos_predict1/utils/log.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import os +from typing import Any, Optional + +import torch.distributed as dist +from loguru._logger import Core, Logger +from tqdm import tqdm + +RANK0_ONLY = True +LEVEL = os.environ.get("LOGURU_LEVEL", "INFO") + +logger = Logger( + core=Core(), + exception=None, + depth=1, + record=False, + lazy=False, + colors=False, + raw=False, + capture=True, + patchers=[], + extra={}, +) + +atexit.register(logger.remove) + + +def _add_relative_path(record: dict[str, Any]) -> None: + start = os.getcwd() + record["extra"]["relative_path"] = os.path.relpath(record["file"].path, start) + + +*options, _, extra = logger._options # type: ignore +logger._options = tuple([*options, [_add_relative_path], extra]) # type: ignore + + +def init_loguru_stdout() -> None: + logger.remove() + machine_format = get_machine_format() + message_format = get_message_format() + logger.add( + lambda msg: tqdm.write(msg, end=""), # stdout is replaced with tqdm.write to avoid tqdm log pollution.. 
+ level=LEVEL, + format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", + filter=_rank0_only_filter, + ) + + +def init_loguru_file(path: str) -> None: + machine_format = get_machine_format() + message_format = get_message_format() + logger.add( + path, + encoding="utf8", + level=LEVEL, + format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", + rotation="100 MB", + filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY, + enqueue=True, + ) + + +def get_machine_format() -> str: + node_id = os.environ.get("NGC_ARRAY_INDEX", "0") + num_nodes = int(os.environ.get("NGC_ARRAY_SIZE", "1")) + machine_format = "" + rank = 0 + if dist.is_available(): + if not RANK0_ONLY and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + machine_format = ( + f"[Node{node_id:<3}/{num_nodes:<3}][RANK{rank:<5}/{world_size:<5}]" + "[{process.name:<8}]| " + ) + return machine_format + + +def get_message_format() -> str: + message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}" + return message_format + + +def _rank0_only_filter(record: Any) -> bool: + is_rank0 = record["extra"].get("rank0_only", True) + if _get_rank() == 0 and is_rank0: + return True + if not is_rank0: + record["message"] = f"[RANK {_get_rank()}] " + record["message"] + return not is_rank0 + + +def trace(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) + + +def debug(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) + + +def info(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) + + +def success(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) + + +def warning(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) + + +def error(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) + + +def critical(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) + + +def exception(message: str, rank0_only: bool = True) -> None: + logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) + + +def _get_rank(group: Optional[dist.ProcessGroup] = None) -> int: + """Get the rank (GPU device) of the worker. + + Returns: + rank (int): The rank of the worker. + """ + rank = 0 + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank(group) + return rank + + +# Execute at import time. +init_loguru_stdout() diff --git a/cosmos_predict1/utils/misc.py b/cosmos_predict1/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..ce923cc656093ff4edd6c2a8bb01568c033795c9 --- /dev/null +++ b/cosmos_predict1/utils/misc.py @@ -0,0 +1,557 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import collections +import collections.abc +import functools +import json +import os +import random +import time +from contextlib import ContextDecorator +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, TypeVar +from urllib.parse import urlparse + +import boto3 +import numpy as np +import termcolor +import torch +from torch import nn +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed._tensor.api import DTensor + +from cosmos_predict1.utils import distributed, log +from cosmos_predict1.utils.easy_io import easy_io + + +def to( + data: Any, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + memory_format: torch.memory_format = torch.preserve_format, +) -> Any: + """Recursively cast data into the specified device, dtype, and/or memory_format. + + The input data can be a tensor, a list of tensors, a dict of tensors. + See the documentation for torch.Tensor.to() for details. + + Args: + data (Any): Input data. + device (str | torch.device): GPU device (default: None). + dtype (torch.dtype): data type (default: None). + memory_format (torch.memory_format): memory organization format (default: torch.preserve_format). + + Returns: + data (Any): Data cast to the specified device, dtype, and/or memory_format. + """ + assert ( + device is not None or dtype is not None or memory_format is not None + ), "at least one of device, dtype, memory_format should be specified" + if isinstance(data, torch.Tensor): + is_cpu = (isinstance(device, str) and device == "cpu") or ( + isinstance(device, torch.device) and device.type == "cpu" + ) + data = data.to( + device=device, + dtype=dtype, + memory_format=memory_format, + non_blocking=(not is_cpu), + ) + return data + elif isinstance(data, collections.abc.Mapping): + return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): + return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) + else: + return data + + +def serialize(data: Any) -> Any: + """Serialize data by hierarchically traversing through iterables. + + Args: + data (Any): Input data. + + Returns: + data (Any): Serialized data. + """ + if isinstance(data, collections.abc.Mapping): + return type(data)({key: serialize(data[key]) for key in data}) + elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): + return type(data)([serialize(elem) for elem in data]) + else: + try: + json.dumps(data) + except TypeError: + data = str(data) + return data + + +def print_environ_variables(env_vars: list[str]) -> None: + """Print a specific list of environment variables. + + Args: + env_vars (list[str]): List of specified environment variables. 
+ """ + for env_var in env_vars: + if env_var in os.environ: + log.info(f"Environment variable {Color.green(env_var)}: {Color.yellow(os.environ[env_var])}") + else: + log.warning(f"Environment variable {Color.green(env_var)} not set!") + + +def set_random_seed(seed: int, by_rank: bool = False) -> None: + """Set random seed. This includes random, numpy, Pytorch. + + Args: + seed (int): Random seed. + by_rank (bool): if true, each GPU will use a different random seed. + """ + if by_rank: + seed += distributed.get_rank() + log.info(f"Using random seed {seed}.") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) # sets seed on the current CPU & all GPUs + + +def arch_invariant_rand( + shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None +): + """Produce a GPU-architecture-invariant randomized Torch tensor. + + Args: + shape (list or tuple of ints): Output tensor shape. + dtype (torch.dtype): Output tensor type. + device (torch.device): Device holding the output. + seed (int): Optional randomization seed. + + Returns: + tensor (torch.tensor): Randomly-generated tensor. + """ + # Create a random number generator, optionally seeded + rng = np.random.RandomState(seed) + + # # Generate random numbers using the generator + random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution + + # Convert to torch tensor and return + return torch.from_numpy(random_array).to(dtype=dtype, device=device) + + +T = TypeVar("T", bound=Callable[..., Any]) + + +class timer(ContextDecorator): # noqa: N801 + """Simple timer for timing the execution of code. + + It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. + + Example: + def func_a(): + time.sleep(1) + with timer("func_a"): + func_a() + + @timer("func_b) + def func_b(): + time.sleep(1) + func_b() + """ + + def __init__(self, context: str, debug: bool = False): + self.context = context + self.debug = debug + + def __enter__(self) -> None: + self.tic = time.time() + + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 + time_spent = time.time() - self.tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + + def __call__(self, func: T) -> T: + @functools.wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN202 + tic = time.time() + result = func(*args, **kwargs) + time_spent = time.time() - tic + if self.debug: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + else: + log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") + return result + + return wrapper # type: ignore + + +class TrainingTimer: + """Timer for timing the execution of code, aggregating over multiple training iterations. + + It is used as a context manager to measure the execution time of code and store the timing results + for each function. The context managers can be nested. + + Attributes: + results (dict): A dictionary to store timing results for various code. 
+ + Example: + timer = Timer() + for i in range(100): + with timer("func_a"): + func_a() + avg_time = sum(timer.results["func_a"]) / len(timer.results["func_a"]) + print(f"func_a() took {avg_time} seconds.") + """ + + def __init__(self) -> None: + self.results = dict() + self.average_results = dict() + self.start_time = [] + self.func_stack = [] + self.reset() + + def reset(self) -> None: + self.results = {key: [] for key in self.results} + + def __enter__(self) -> TrainingTimer: + self.start_time.append(time.time()) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 + end_time = time.time() + result = end_time - self.start_time.pop() + key = self.func_stack.pop() + self.results.setdefault(key, []) + self.results[key].append(result) + + def __call__(self, func_name: str) -> TrainingTimer: + self.func_stack.append(func_name) + return self + + def __getattr__(self, func_name: str) -> TrainingTimer: + return self.__call__(func_name) + + def nested(self, func_name: str) -> TrainingTimer: + return self.__call__(func_name) + + def compute_average_results(self) -> dict[str, float]: + results = dict() + for key, value_list in self.results.items(): + results[key] = sum(value_list) / len(value_list) + return results + + +def timeout_handler(timeout_period: float, signum: int, frame: int) -> None: + # What to do when the process gets stuck. For now, we simply end the process. + error_message = f"Timeout error: more than {timeout_period} seconds passed since the last iteration." + raise TimeoutError(error_message) + + +class Color: + """A convenience class to colorize strings in the console. + + Example: + import + print("This is {Color.red('important')}.") + """ + + @staticmethod + def red(x: str) -> str: + return termcolor.colored(str(x), color="red") + + @staticmethod + def green(x: str) -> str: + return termcolor.colored(str(x), color="green") + + @staticmethod + def cyan(x: str) -> str: + return termcolor.colored(str(x), color="cyan") + + @staticmethod + def yellow(x: str) -> str: + return termcolor.colored(str(x), color="yellow") + + +class BufferCnt: + """ + Buffer counter which keeps track of the condition when called and returns True when the condition in met "thres" + amount of times, otherwise returns False. + + Example usage: + buf = BufferCnt(thres=3) + for _ in range(5): + if buf(random.random() > 0.5): + print("We got lucky 3 times out of 5.") + + Args: + thres (int): The amount of times the expression needs to be True before returning True. + reset_over_thres (bool): Whether to reset the buffer after returning True. 
+ """ + + def __init__(self, thres=10, reset_over_thres=False): + self._cnt = 0 + self.thres = thres + self.reset_over_thres = reset_over_thres + + def __call__(self, expre, thres=None): + if expre is True: + self._cnt += 1 + else: + self._cnt = 0 + + if thres is None: + thres = self.thres + + if self._cnt >= thres: + if self.reset_over_thres: + self.reset() + return True + + return False + + @property + def cnt(self): + return self._cnt + + def reset(self): + self._cnt = 0 + + +def get_local_tensor_if_DTensor(tensor: torch.Tensor | DTensor) -> torch.tensor: + if isinstance(tensor, DTensor): + local = tensor.to_local() + # As per PyTorch documentation, if the communication is not finished yet, we need to wait for it to finish + # https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.DTensor.to_local + if isinstance(local, AsyncCollectiveTensor): + return local.wait() + else: + return local + return tensor + + +def disabled_train(self: Any, mode: bool = True) -> Any: + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def count_params(model: nn.Module, verbose=False) -> int: + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def expand_dims_like(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def download_from_s3_with_cache( + s3_path: str, + cache_fp: Optional[str] = None, + cache_dir: Optional[str] = None, + rank_sync: bool = True, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, +) -> str: + """download data from S3 with optional caching. + + This function first attempts to load the data from a local cache file. If + the cache file doesn't exist, it downloads the data from S3 to the cache + location. Caching is performed in a rank-aware manner + using `distributed.barrier()` to ensure only one download occurs across + distributed workers (if `rank_sync` is True). + + Args: + s3_path (str): The S3 path of the data to load. + cache_fp (str, optional): The path to the local cache file. If None, + a filename will be generated based on `s3_path` within `cache_dir`. + cache_dir (str, optional): The directory to store the cache file. If + None, the environment variable `COSMOS_CACHE_DIR` (defaulting + to "/tmp") will be used. + rank_sync (bool, optional): Whether to synchronize download across + distributed workers using `distributed.barrier()`. Defaults to True. + backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. + backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. + + Returns: + cache_fp (str): The path to the local cache file. + + Raises: + FileNotFoundError: If the data cannot be found in S3 or the cache. 
+ """ + cache_dir = os.environ.get("TORCH_HOME") if cache_dir is None else cache_dir + cache_dir = ( + os.environ.get("COSMOS_CACHE_DIR", os.path.expanduser("~/.cache/cosmos")) if cache_dir is None else cache_dir + ) + cache_dir = os.path.expanduser(cache_dir) + if cache_fp is None: + cache_fp = os.path.join(cache_dir, s3_path.replace("s3://", "")) + if not cache_fp.startswith("/"): + cache_fp = os.path.join(cache_dir, cache_fp) + + if distributed.get_rank() == 0: + if os.path.exists(cache_fp): + # check the size of cache_fp + if os.path.getsize(cache_fp) < 1: + os.remove(cache_fp) + log.warning(f"Removed empty cache file {cache_fp}.") + + if rank_sync: + if not os.path.exists(cache_fp): + log.critical(f"Local cache {cache_fp} Not exist! Downloading {s3_path} to {cache_fp}.") + log.info(f"backend_args: {backend_args}") + log.info(f"backend_key: {backend_key}") + + easy_io.copyfile_to_local( + s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key + ) + log.info(f"Downloaded {s3_path} to {cache_fp}.") + else: + log.info(f"Local cache {cache_fp} already exist! {s3_path} -> {cache_fp}.") + + distributed.barrier() + else: + if not os.path.exists(cache_fp): + easy_io.copyfile_to_local( + s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key + ) + + log.info(f"Downloaded {s3_path} to {cache_fp}.") + return cache_fp + + +def load_from_s3_with_cache( + s3_path: str, + cache_fp: Optional[str] = None, + cache_dir: Optional[str] = None, + rank_sync: bool = True, + backend_args: Optional[dict] = None, + backend_key: Optional[str] = None, + easy_io_kwargs: Optional[dict] = None, +) -> Any: + """Loads data from S3 with optional caching. + + This function first attempts to load the data from a local cache file. If + the cache file doesn't exist, it downloads the data from S3 to the cache + location and then loads it. Caching is performed in a rank-aware manner + using `distributed.barrier()` to ensure only one download occurs across + distributed workers (if `rank_sync` is True). + + Args: + s3_path (str): The S3 path of the data to load. + cache_fp (str, optional): The path to the local cache file. If None, + a filename will be generated based on `s3_path` within `cache_dir`. + cache_dir (str, optional): The directory to store the cache file. If + None, the environment variable `COSMOS_CACHE_DIR` (defaulting + to "/tmp") will be used. + rank_sync (bool, optional): Whether to synchronize download across + distributed workers using `distributed.barrier()`. Defaults to True. + backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. + backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. + + Returns: + Any: The loaded data from the S3 path or cache file. + + Raises: + FileNotFoundError: If the data cannot be found in S3 or the cache. + """ + cache_fp = download_from_s3_with_cache(s3_path, cache_fp, cache_dir, rank_sync, backend_args, backend_key) + + if easy_io_kwargs is None: + easy_io_kwargs = {} + return easy_io.load(cache_fp, **easy_io_kwargs) + + +def sync_s3_dir_to_local( + s3_dir: str, + s3_credential_path: str, + cache_dir: Optional[str] = None, + rank_sync: bool = True, +) -> str: + """ + Download an entire directory from S3 to the local cache directory. + + Args: + s3_dir (str): The AWS S3 directory to download. + s3_credential_path (str): The path to the AWS S3 credentials file. 
+ rank_sync (bool, optional): Whether to synchronize download across + distributed workers using `distributed.barrier()`. Defaults to True. + cache_dir (str, optional): The cache folder to sync the S3 directory to. + If None, the environment variable `COSMOS_CACHE_DIR` (defaulting + to "~/.cache/cosmos") will be used. + + Returns: + local_dir (str): The path to the local directory. + """ + if not s3_dir.startswith("s3://"): + # If the directory exists locally, return the local path + assert os.path.exists(s3_dir), f"{s3_dir} is not a S3 path or a local path." + return s3_dir + + # Load AWS credentials from the file + with open(s3_credential_path, "r") as f: + credentials = json.load(f) + + # Create an S3 client + s3 = boto3.client( + "s3", + **credentials, + ) + + # Parse the S3 URL + parsed_url = urlparse(s3_dir) + source_bucket = parsed_url.netloc + source_prefix = parsed_url.path.lstrip("/") + + # If the local directory is not specified, use the default cache directory + cache_dir = ( + os.environ.get("COSMOS_CACHE_DIR", os.path.expanduser("~/.cache/cosmos")) if cache_dir is None else cache_dir + ) + cache_dir = os.path.expanduser(cache_dir) + Path(cache_dir).mkdir(parents=True, exist_ok=True) + + # List objects in the bucket with the given prefix + response = s3.list_objects_v2(Bucket=source_bucket, Prefix=source_prefix) + # Download each matching object + for obj in response.get("Contents", []): + if obj["Key"].startswith(source_prefix): + # Create the full path for the destination file, preserving the directory structure + rel_path = os.path.relpath(obj["Key"], source_prefix) + dest_path = os.path.join(cache_dir, source_prefix, rel_path) + + # Ensure the directory exists + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + + # Check if the file already exists + if os.path.exists(dest_path): + continue + else: + log.info(f"Downloading {obj['Key']} to {dest_path}") + # Download the file + if not rank_sync or distributed.get_rank() == 0: + s3.download_file(source_bucket, obj["Key"], dest_path) + if rank_sync: + distributed.barrier() + local_dir = os.path.join(cache_dir, source_prefix) + return local_dir diff --git a/cosmos_predict1/utils/model.py b/cosmos_predict1/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e6add316d05e17ca23ec6b9d7fb56ec744bbb220 --- /dev/null +++ b/cosmos_predict1/utils/model.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from cosmos_predict1.utils.lazy_config import LazyDict, instantiate + + +class Model(torch.nn.Module): + """The base model class. It is inherited from torch.nn.Module. + + All models should inherit Model. It should include the implementions for all the + computation graphs. 
All inheriting child classes should implement the following methods: + - training_step(): The training step of the model, including the loss computation. + - validation_step(): The validation step of the model, including the loss computation. + - forward(): The computation graph for model inference. + The following methods have default implementations in Model: + - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model. + """ + + def __init__(self) -> None: + super().__init__() + self.on_model_init_start(set_barrier=False) + + def init_optimizer_scheduler( + self, optimizer_config: LazyDict, scheduler_config: LazyDict + ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Creates the optimizer and scheduler for the model. + + Args: + config_model (ModelConfig): The config object for the model. + + Returns: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + """ + optimizer_config.params = self.parameters() + optimizer = instantiate(optimizer_config) + scheduler_config.optimizer = optimizer + scheduler = instantiate(scheduler_config) + return optimizer, scheduler + + def training_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """The training step of the model, including the loss computation. + + Args: + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + + Returns: + output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch. + loss (torch.Tensor): The total loss for backprop (weighted sum of various losses). + """ + raise NotImplementedError + + @torch.no_grad() + def validation_step( + self, data_batch: dict[str, torch.Tensor], iteration: int + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """The validation step of the model, including the loss computation. + + Args: + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + + Returns: + output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch. + loss (torch.Tensor): The total loss (weighted sum of various losses). + """ + raise NotImplementedError + + @torch.inference_mode() + def forward(self, *args: Any, **kwargs: Any) -> Any: + """The computation graph for model inference. + + Args: + *args: Whatever you decide to pass into the forward method. + **kwargs: Keyword arguments are also possible. + + Return: + Your model's output. + """ + raise NotImplementedError + + def on_model_init_start(self, set_barrier=False) -> None: + return + + def on_model_init_end(self, set_barrier=False) -> None: + return + + def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: + """The model preparation before the training is launched + + Args: + memory_format (torch.memory_format): Memory format of the model. + """ + pass + + def on_before_zero_grad( + self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int + ) -> None: + """Hook before zero_grad() is called. + + Args: + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + iteration (int): Current iteration number. + """ + pass + + def on_after_backward(self, iteration: int = 0) -> None: + """Hook after loss.backward() is called. 
+ + This method is called immediately after the backward pass, allowing for custom operations + or modifications to be performed on the gradients before the optimizer step. + + Args: + iteration (int): Current iteration number. + """ + pass diff --git a/cosmos_predict1/utils/parallel_state_helper.py b/cosmos_predict1/utils/parallel_state_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..f531ab00c9d45a7dbf5015a43147bf635d72c5ec --- /dev/null +++ b/cosmos_predict1/utils/parallel_state_helper.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import parallel_state + + +def is_tp_cp_pp_rank0(): + return ( + parallel_state.get_tensor_model_parallel_rank() == 0 + and parallel_state.get_pipeline_model_parallel_rank() == 0 + and parallel_state.get_context_parallel_rank() == 0 + ) diff --git a/cosmos_predict1/utils/scheduler.py b/cosmos_predict1/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..d344b990dc55920947645b52d70e201c164f86d7 --- /dev/null +++ b/cosmos_predict1/utils/scheduler.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
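To make the Model contract above concrete, here is a hedged sketch of a minimal subclass; the class name, tensor keys, and loss choice are illustrative and not part of this codebase.

import torch

from cosmos_predict1.utils.model import Model


class ToyRegressionModel(Model):
    # Illustrative only: a linear model trained with MSE loss.
    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        pred = self.net(data_batch["x"])
        loss = torch.nn.functional.mse_loss(pred, data_batch["y"])
        return {"pred": pred}, loss  # (auxiliary outputs, total loss for backprop)

    @torch.no_grad()
    def validation_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        return self.training_step(data_batch, iteration)

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

init_optimizer_scheduler() is inherited unchanged, so the optimizer and scheduler LazyDicts from the experiment config are instantiated against self.parameters().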
+ +import math +from typing import List + +import torch + + +class WarmupLambdaLR(torch.optim.lr_scheduler.LambdaLR): + def __init__(self, optimizer, warmup, last_epoch=-1, verbose=False): + # Define the lambda function based on the warmup period + self.warmup = warmup + + def lr_lambda(epoch): + # Increase lr linearly for the first 'warmup' epochs + if epoch < warmup: + return float(epoch + 1) / warmup + # After 'warmup' epochs, keep lr constant + return 1.0 + + # Initialize the parent class with the generated lr_lambda + super(WarmupLambdaLR, self).__init__(optimizer, lr_lambda, last_epoch, verbose) + + +# cosine lr decay scheduler with warmup from https://github.com/karpathy/nanoGPT/blob/master/train.py#L228 +class WarmupCosineLR(torch.optim.lr_scheduler.LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_iters: int, + lr_decay_iters: int, + min_lr: float, + last_epoch: int = -1, + ): + self.warmup_iters = warmup_iters + self.lr_decay_iters = lr_decay_iters + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + # 1) linear warmup for warmup_iters steps + if self.last_epoch < self.warmup_iters: + return [base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs] + # 2) if it > lr_decay_iters, return min learning rate + if self.last_epoch > self.lr_decay_iters: + return [self.min_lr for _ in self.base_lrs] + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (self.last_epoch - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + return [self.min_lr + coeff * (base_lr - self.min_lr) for base_lr in self.base_lrs] diff --git a/cosmos_predict1/utils/trainer.py b/cosmos_predict1/utils/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca0b8349a0dc61dcedf2e603a3fbdf231d1badc --- /dev/null +++ b/cosmos_predict1/utils/trainer.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import signal + +import torch +import torch.distributed as dist +import torch.utils.data +from megatron.core import parallel_state + +from cosmos_predict1.utils import callback, distributed, ema, log, misc +from cosmos_predict1.utils.checkpointer import Checkpointer +from cosmos_predict1.utils.lazy_config import LazyConfig, instantiate +from cosmos_predict1.utils.model import Model + + +class Trainer: + """The base trainer class. + + All trainers should inherit Trainer. It contains the basic functionality for model training + (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA), + mixed-precision training (fp16/bf16). 
+ + Attributes: + checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states. + training_timer (misc.Timer): Timer object to time code blocks and functions. + """ + + def __init__(self, config): + """Constructor of the trainer. + + Args: + config (Config): The config object for the codebase. + """ + super().__init__() + self.config = config + # Set up the distributed computing environment. + with misc.timer("init_distributed"): + distributed.init() + # Set up parallel states. + if hasattr(config.model, "context_parallel_size"): + if config.model_parallel.context_parallel_size > 1: + raise ValueError( + "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. " + "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size." + ) + else: + log.critical( + "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead." + ) + config.model_parallel.context_parallel_size = config.model.context_parallel_size + parallel_state.initialize_model_parallel( + pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size, + tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size, + context_parallel_size=config.model_parallel.context_parallel_size, + ) + # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism. + # It is not part of the original `parallel_state` API, so we need to set it manually. + parallel_state.sequence_parallel = config.model_parallel.sequence_parallel + if parallel_state.sequence_parallel: + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # Create the local job directory, save the config file, and pipe to a local log. + if distributed.is_rank0(): + os.makedirs(config.job.path_local, exist_ok=True) + # Save the config as .pkl for reproducibility. + LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl") + # Save the config as .yaml for reading or parsing experiment hyperparameters. + LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") + dist.barrier() + log.init_loguru_file(f"{config.job.path_local}/stdout.log") + if distributed.is_rank0(): + # Print important environment variables and the effective config. + log.info("Config:\n" + config.pretty_print(use_color=True)) + misc.print_environ_variables(["TORCH_HOME", "OUTPUT_ROOT"]) + # Set the random seed. If multi-GPU, different ranks are set with different seeds. + misc.set_random_seed(seed=config.trainer.seed, by_rank=True) + # Initialize cuDNN. + torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic + torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark + # Floating-point precision settings. + torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True + # Initialize the callback functions. + self.callbacks = callback.CallBackGroup(config=config, trainer=self) + # Initialize the model checkpointer. + if config.checkpoint.type is None: + self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks) + else: + self.checkpointer: Checkpointer = instantiate( + config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks + ) + # Initialize the timer for speed benchmarking. + self.training_timer = misc.TrainingTimer() + # Send a TimeoutError if a training step takes over timeout_period seconds. 
+ signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore + + def train( + self, + model: Model, + dataloader_train: torch.utils.data.DataLoader, + dataloader_val: torch.utils.data.DataLoader, + ) -> None: + """The training function. + + Args: + model (Model): The PyTorch model. + dataloader_train (torch.utils.data.DataLoader): The training data loader. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + """ + # Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models. + model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore + model.on_train_start(self.config.trainer.memory_format) + + # Initialize the optimizer, scheduler, and grad_scaler. + self.callbacks.on_optimizer_init_start() + optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler) + grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args) + self.callbacks.on_optimizer_init_end() + # Load the model checkpoint and get the starting iteration number. + iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler) + grad_accum_iter = 0 + log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + if self.config.trainer.distributed_parallelism == "ddp": + # Create a DDP model wrapper. + model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model) + elif self.config.trainer.distributed_parallelism == "fsdp": + model_ddp = model + else: + raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") + log.info("Starting training...") + self.callbacks.on_train_start(model, iteration=iteration) + # Initial validation. + if self.config.trainer.run_validation and iteration == 0: + self.validate(model, dataloader_val, iteration=iteration) + _end_training = False + while True: + dataloader_train_iter = iter(dataloader_train) + while True: + self.callbacks.on_before_dataloading(iteration) + with self.training_timer("dataloader_train"): + try: + data_batch = next(dataloader_train_iter) + for k in data_batch.keys(): + if torch.is_tensor(data_batch[k]): + data_batch[k] = data_batch[k].cuda() + except StopIteration: + break + self.callbacks.on_after_dataloading(iteration) + # If max_iter is reached, exit the training loop. + if iteration >= self.config.trainer.max_iter: + _end_training = True + break + # Move all tensors in the data batch to GPU device. + data_batch = misc.to(data_batch, device="cuda") + # The actual training step. + self.callbacks.on_training_step_start(model, data_batch, iteration=iteration) + if not model.training: + model_ddp.train() + assert model_ddp.training, "model_ddp is not in training mode." + assert model.training, "model is not in training mode." + output_batch, loss, grad_accum_iter = self.training_step( + model_ddp, + optimizer, + scheduler, + grad_scaler, + data_batch, + iteration=iteration, + grad_accum_iter=grad_accum_iter, + ) + # Do the following when an actual optimizer (update) step has been made. + iteration += 1 + # Save checkpoint. 
+ if iteration % self.config.checkpoint.save_iter == 0: + async_saving = getattr(self.config.checkpoint, "async_saving", True) + self.checkpointer.save( + model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving + ) + self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration) + # Validation. + if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0: + self.validate(model, dataloader_val, iteration=iteration) + # This iteration is successful; reset the timeout signal. + signal.alarm(self.config.trainer.timeout_period) + if _end_training: + break + log.success("Done with training.") + if iteration % self.config.checkpoint.save_iter != 0: + async_saving = getattr(self.config.checkpoint, "async_saving", True) + self.checkpointer.save( + model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving + ) + self.callbacks.on_train_end(model, iteration=iteration) + self.checkpointer.finalize() + distributed.barrier() + self.callbacks.on_app_end() + + def training_step( + self, + model_ddp: torch.nn.Module | distributed.DistributedDataParallel, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + grad_scaler: torch.amp.GradScaler, + data: dict[str, torch.Tensor], + iteration: int = 0, + grad_accum_iter: int = 0, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]: + """The training step. + + Args: + model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare + module, depending on whether distributed training is enabled or not. + optimizer (torch.optim.Optimizer): The model optimizer. + scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. + grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). + data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). + iteration (int): Current iteration number. + grad_accum_iter (int): Number of gradient accumulation iterations. + + Returns: + output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). + loss (torch.Tensor): The total loss of the training data batch. 
+ """ + # Only let DDP sync gradient at the last iteration of the gradient accumulation window + with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1): + with self.training_timer("forward"): + output_batch, loss = model_ddp.training_step(data, iteration) + self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration) + with self.training_timer("backward"): + loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter) + loss_scaled.backward() + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_after_backward() + else: + model_ddp.on_after_backward() + self.callbacks.on_after_backward(model_ddp, iteration=iteration) + grad_accum_iter += 1 + if grad_accum_iter == self.config.trainer.grad_accum_iter: + with self.training_timer("optimizer_step"): + self.callbacks.on_before_optimizer_step( + model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration + ) + grad_scaler.step(optimizer) + grad_scaler.update() + scheduler.step() + self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration) + if self.config.trainer.distributed_parallelism == "ddp": + model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + else: + model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration) + optimizer.zero_grad(set_to_none=True) + grad_accum_iter = 0 + return output_batch, loss, grad_accum_iter + + @torch.no_grad() + def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: + """Validate on the full validation dataset. + + Args: + model (Model): The PyTorch model. + dataloader_val (torch.utils.data.DataLoader): The validation data loader. + iteration (int): Current iteration number. + """ + self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) + model.eval() + # Evaluate on the full validation set. + with ema.ema_scope(model, enabled=model.config.ema.enabled): + for val_iter, data_batch in enumerate(dataloader_val): + if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: + break + data_batch = misc.to(data_batch, device="cuda") + self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) + output_batch, loss = model.validation_step(data_batch, iteration) + self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) + self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/cosmos_predict1/utils/validator.py b/cosmos_predict1/utils/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..837564a84d68079afb90870ed7181937e2f4df73 --- /dev/null +++ b/cosmos_predict1/utils/validator.py @@ -0,0 +1,503 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import ast +import base64 +import itertools +import json +import os +from abc import ABC, abstractmethod +from io import BytesIO +from typing import Any, List, Union + + +# from https://docs.python.org/3/howto/descriptor.html#validator-class +# For usage of hidden flag see the ModelParams class in apis/utils/model_params.py +class Validator(ABC): + # set name is called when the validator is created as class variable + # name is the name of the variable in the owner class, so here we create the name for the backing variable + def __set_name__(self, owner, name): + self.private_name = "_" + name + + def __get__(self, obj, objtype=None): + return getattr(obj, self.private_name, self.default) + + def __set__(self, obj, value): + value = self.validate(value) + setattr(obj, self.private_name, value) + + @abstractmethod + def validate(self, value): + pass + + def json(self): + pass + + +class MultipleOf(Validator): + def __init__(self, default: int, multiple_of: int, type_cast=None, hidden=False, tooltip=None): + if type(multiple_of) is not int: + raise ValueError(f"Expected {multiple_of!r} to be an int") + self.multiple_of = multiple_of + self.default = default + self.type_cast = type_cast + + # For usage of hidden flag see the ModelParams class in apis/utils/model_params.py + # if a parameter is hidden then probe() can't expose the param + # and the param can't be set anymore + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if self.type_cast: + try: + value = self.type_cast(value) + except ValueError: + raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") + + if value % self.multiple_of != 0: + raise ValueError(f"Expected {value!r} to be a multiple of {self.multiple_of!r}") + + return value + + def get_range_iterator(self): + return itertools.count(0, self.multiple_of) + + def __repr__(self) -> str: + return f"MultipleOf({self.private_name=} {self.multiple_of=} {self.hidden=})" + + def json(self): + return { + "type": MultipleOf.__name__, + "default": self.default, + "multiple_of": self.multiple_of, + "tooltip": self.tooltip, + } + + +class OneOf(Validator): + def __init__(self, default, options, type_cast=None, hidden=False, tooltip=None): + self.options = set(options) + self.default = default + self.type_cast = type_cast # Cast the value to this type before checking if it's in options + self.tooltip = tooltip + self.hidden = hidden + + def validate(self, value): + if self.type_cast: + try: + value = self.type_cast(value) + except ValueError: + raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") + + if value not in self.options: + raise ValueError(f"Expected {value!r} to be one of {self.options!r}") + + return value + + def get_range_iterator(self): + return self.options + + def __repr__(self) -> str: + return f"OneOf({self.private_name=} {self.options=} {self.hidden=})" + + def json(self): + return { + "type": OneOf.__name__, + "default": self.default, + "values": list(self.options), + "tooltip": self.tooltip, + } + + +class HumanAttributes(Validator): + def __init__(self, default, hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + # hard code the options for now + # we extend this to init parameter as needed + valid_attributes = { + "emotion": ["angry", "contemptful", "disgusted", "fearful", "happy", "neutral", "sad", "surprised"], + "race": ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"], + "gender": ["male", "female"], + "age 
group": [ + "young", + "teen", + "adult early twenties", + "adult late twenties", + "adult early thirties", + "adult late thirties", + "adult middle aged", + "older adult", + ], + } + + def get_range_iterator(self): + # create a list of all possible combinations + l1 = self.valid_attributes["emotion"] + l2 = self.valid_attributes["race"] + l3 = self.valid_attributes["gender"] + l4 = self.valid_attributes["age group"] + all_combinations = list(itertools.product(l1, l2, l3, l4)) + return iter(all_combinations) + + def validate(self, value): + human_attributes = value.lower() + if human_attributes not in ["none", "random"]: + # In this case, we need for custom attribute string + + attr_string = human_attributes + for attr_key in ["emotion", "race", "gender", "age group"]: + attr_detected = False + for attr_label in self.valid_attributes[attr_key]: + if attr_string.startswith(attr_label): + attr_string = attr_string[len(attr_label) + 1 :] # noqa: E203 + attr_detected = True + break + + if attr_detected is False: + raise ValueError(f"Expected {value!r} to be one of {self.valid_attributes!r}") + + return value + + def __repr__(self) -> str: + return f"HumanAttributes({self.private_name=} {self.hidden=})" + + def json(self): + return { + "type": HumanAttributes.__name__, + "default": self.default, + "values": self.valid_attributes, + "tooltip": self.tooltip, + } + + +class Bool(Validator): + def __init__(self, default, hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if isinstance(value, int): + value = value != 0 + elif isinstance(value, str): + value = value.lower() + if value in ["true", "1"]: + value = True + elif value in ["false", "0"]: + value = False + else: + raise ValueError(f"Expected {value!r} to be one of ['True', 'False', '1', '0']") + elif not isinstance(value, bool): + raise TypeError(f"Expected {value!r} to be an bool") + + return value + + def get_range_iterator(self): + return [True, False] + + def __repr__(self) -> str: + return f"Bool({self.private_name=} {self.default=} {self.hidden=})" + + def json(self): + return { + "type": bool.__name__, + "default": self.default, + "tooltip": self.tooltip, + } + + +class Int(Validator): + def __init__(self, default, min=None, max=None, step=1, hidden=False, tooltip=None): + self.min = min + self.max = max + self.default = default + self.step = step + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if isinstance(value, str): + value = int(value) + elif not isinstance(value, int): + raise TypeError(f"Expected {value!r} to be an int") + + if self.min is not None and value < self.min: + raise ValueError(f"Expected {value!r} to be at least {self.min!r}") + if self.max is not None and value > self.max: + raise ValueError(f"Expected {value!r} to be no more than {self.max!r}") + return value + + def get_range_iterator(self): + iter_min = self.min if self.min is not None else self.default + iter_max = self.max if self.max is not None else self.default + return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step)) + + def __repr__(self) -> str: + return f"Int({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" + + def json(self): + return { + "type": int.__name__, + "default": self.default, + "min": self.min, + "max": self.max, + "step": self.step, + "tooltip": self.tooltip, + } + + +class Float(Validator): + def __init__(self, default=0.0, min=None, max=None, step=0.5, 
hidden=False, tooltip=None): + self.min = min + self.max = max + self.default = default + self.step = step + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if isinstance(value, str) or isinstance(value, int): + value = float(value) + elif not isinstance(value, float): + raise TypeError(f"Expected {value!r} to be a float") + + if self.min is not None and value < self.min: + raise ValueError(f"Expected {value!r} to be at least {self.min!r}") + if self.max is not None and value > self.max: + raise ValueError(f"Expected {value!r} to be no more than {self.max!r}") + return value + + def get_range_iterator(self): + iter_min = self.min if self.min is not None else self.default + iter_max = self.max if self.max is not None else self.default + return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step)) + + def __repr__(self) -> str: + return f"Float({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" + + def json(self): + return { + "type": float.__name__, + "default": self.default, + "min": self.min, + "max": self.max, + "step": self.step, + "tooltip": self.tooltip, + } + + +class String(Validator): + def __init__(self, default="", min=None, max=None, predicate=None, hidden=False, tooltip=None): + self.min = min + self.max = max + self.predicate = predicate + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if not isinstance(value, str): + raise TypeError(f"Expected {value!r} to be a str") + if self.min is not None and len(value) < self.min: + raise ValueError(f"Expected {value!r} to be no smaller than {self.min!r}") + if self.max is not None and len(value) > self.max: + raise ValueError(f"Expected {value!r} to be no bigger than {self.max!r}") + if self.predicate is not None and not self.predicate(value): + raise ValueError(f"Expected {self.predicate} to be true for {value!r}") + return value + + def get_range_iterator(self): + return iter([self.default]) + + def __repr__(self) -> str: + return f"String({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" + + def json(self): + return { + "type": str.__name__, + "default": self.default, + "tooltip": self.tooltip, + } + + +class Path(Validator): + def __init__(self, default="", hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value): + if not isinstance(value, str): + raise TypeError(f"Expected {value!r} to be a str") + if not os.path.exists(value): + raise ValueError(f"Expected {value!r} to be a valid path") + + return value + + def get_range_iterator(self): + return iter([self.default]) + + def __repr__(self) -> str: + return f"Path({self.private_name=} {self.default=}, {self.hidden=})" + + +class InputImage(Validator): + def __init__(self, default="", hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + valid_formats = { + "JPEG": ["jpeg", "jpg"], + "JPEG2000": ["jp2"], + "PNG": ["png"], + "GIF": ["gif"], + "BMP": ["bmp"], + } + + valid_extensions = {vi: k for k, v in valid_formats.items() for vi in v} + + def validate(self, value): + if not isinstance(value, str): + raise TypeError(f"Expected {value!r} to be a str") + if not os.path.exists(value): + raise ValueError(f"Expected {value!r} to be a valid path") + _, ext = os.path.splitext(value) + ext = ext.lower().lstrip(".") + if ext not in InputImage.valid_extensions: + raise ValueError(f"Expected {value!r} to have one of the extensions {sorted(InputImage.valid_extensions)!r}") + return value + + def 
get_range_iterator(self): + return iter([self.default]) + + def __repr__(self) -> str: + return f"String({self.private_name=} {self.default=} {self.hidden=})" + + def json(self): + return { + "type": InputImage.__name__, + "default": self.default, + "values": self.valid_formats, + "tooltip": self.tooltip, + } + + +class MeshFormat(Validator): + """ + Validator class for mesh formats. Valid inputs are either: + - single valid format such as "glb", "obj" + - or a list of valid formats such as "[obj, ply, usdz]" + """ + + valid_formats = {"glb", "usdz", "obj", "ply"} + + def __init__(self, default="glb", hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value: str) -> Union[str, List[str]]: + try: + # Attempt to parse the input as a Python list + if value.startswith("[") and value.endswith("]"): + formats = ast.literal_eval(value) + if not all(fmt in MeshFormat.valid_formats for fmt in formats): + raise ValueError(f"Each item must be one of {MeshFormat.valid_formats}") + return formats + elif value in MeshFormat.valid_formats: + return value + else: + raise ValueError(f"Expected {value!r} to be one of {MeshFormat.valid_formats} or a list of them") + except (SyntaxError, ValueError) as e: + # Handle case where the input is neither a valid single format nor a list of valid formats + raise ValueError(f"Invalid format specification: {value}. Error: {str(e)}") + + def __repr__(self) -> str: + return f"MeshFormat(default={self.default}, hidden={self.hidden})" + + def json(self): + return { + "type": MeshFormat.__name__, + "default": self.default, + "values": self.valid_formats, + "tooltip": self.tooltip, + } + + +class JsonDict(Validator): + """ + JSON stringified version of a python dict. + Example: '{"ema_customization_iter.pt": "ema_customization_iter.pt"}' + """ + + def __init__(self, default="", hidden=False): + self.default = default + self.hidden = hidden + + def validate(self, value): + if not value: + return {} + try: + dict = json.loads(value) + return dict + except json.JSONDecodeError as e: + raise ValueError(f"Expected {value!r} to be json stringified dict. Error: {str(e)}") + + def __repr__(self) -> str: + return f"Dict({self.default=} {self.hidden=})" + + +class BytesIOType(Validator): + """ + Validator class for BytesIO. 
Valid inputs are either: + - bytes + - objects of class BytesIO + - str which can be successfully decoded into BytesIO + """ + + def __init__(self, default=None, hidden=False, tooltip=None): + self.default = default + self.hidden = hidden + self.tooltip = tooltip + + def validate(self, value: Any) -> BytesIO: + if isinstance(value, str): + try: + # Decode the Base64 string + decoded_bytes = base64.b64decode(value) + # Create a BytesIO stream from the decoded bytes + return BytesIO(decoded_bytes) + except (base64.binascii.Error, ValueError) as e: + raise ValueError(f"Invalid Base64 encoded string: {e}") + elif isinstance(value, bytes): + return BytesIO(value) + elif isinstance(value, BytesIO): + return value + else: + raise TypeError(f"Expected {value!r} to be a Base64 encoded string, bytes, or BytesIO") + + def __repr__(self) -> str: + return f"BytesIOValidator({self.default=}, {self.hidden=})" + + def json(self): + return { + "type": BytesIO.__name__, + "default": self.default, + "tooltip": self.tooltip, + } diff --git a/cosmos_predict1/utils/visualize/video.py b/cosmos_predict1/utils/visualize/video.py new file mode 100644 index 0000000000000000000000000000000000000000..89fa021ccf8e85aa177f0f30293791c68531d8d0 --- /dev/null +++ b/cosmos_predict1/utils/visualize/video.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import IO, Any, Union + +import cv2 +import numpy as np +import torch +from einops import rearrange +from PIL import Image as PILImage +from torch import Tensor + +from cosmos_predict1.utils import log +from cosmos_predict1.utils.easy_io import easy_io + +try: + import ffmpegcv +except Exception as e: # ImportError cannot catch all problems + log.info(e) + ffmpegcv = None + + +def save_video(grid, video_name, fps=30): + grid = (grid * 255).astype(np.uint8) + grid = np.transpose(grid, (1, 2, 3, 0)) + with ffmpegcv.VideoWriter(video_name, "h264", fps) as writer: + for frame in grid: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + writer.write(frame) + + +def save_img_or_video(sample_C_T_H_W_in01: Tensor, save_fp_wo_ext: Union[str, IO[Any]], fps: int = 24) -> None: + """ + Save a tensor as an image or video file based on shape + + Args: + sample_C_T_H_W_in01 (Tensor): Input tensor with shape (C, T, H, W) in [0, 1] range. + save_fp_wo_ext (Union[str, IO[Any]]): File path without extension or file-like object. + fps (int): Frames per second for video. Default is 24. 
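+ + Example (illustrative only; the tensor values and output paths below are made up): + >>> frames = torch.rand(3, 16, 704, 1280) # shape (C, T, H, W), values in [0, 1] + >>> save_img_or_video(frames, "outputs/sample", fps=24) # T > 1, so this writes outputs/sample.mp4 + >>> save_img_or_video(frames[:, :1], "outputs/frame") # a single frame is written as outputs/frame.jpg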
+ """ + assert sample_C_T_H_W_in01.ndim == 4, "Only support 4D tensor" + assert isinstance(save_fp_wo_ext, str) or hasattr( + save_fp_wo_ext, "write" + ), "save_fp_wo_ext must be a string or file-like object" + + if torch.is_floating_point(sample_C_T_H_W_in01): + sample_C_T_H_W_in01 = sample_C_T_H_W_in01.clamp(0, 1) + else: + assert sample_C_T_H_W_in01.dtype == torch.uint8, "Only support uint8 tensor" + sample_C_T_H_W_in01 = sample_C_T_H_W_in01.float().div(255) + + if sample_C_T_H_W_in01.shape[1] == 1: + save_obj = PILImage.fromarray( + rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c 1 h w -> h w c").astype(np.uint8), + mode="RGB", + ) + ext = ".jpg" if isinstance(save_fp_wo_ext, str) else "" + easy_io.dump( + save_obj, + f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, + file_format="jpg", + format="JPEG", + quality=85, + ) + else: + save_obj = rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c t h w -> t h w c").astype(np.uint8) + ext = ".mp4" if isinstance(save_fp_wo_ext, str) else "" + easy_io.dump( + save_obj, + f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, + file_format="mp4", + format="mp4", + fps=fps, + ) diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bf6e230528e0eca4899141db639951f9dab37ea9 --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,3 @@ +### Datasets directory + +Datasets used to post-train cosmos models will be saved in this directory. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..aab370985d560e0ea8849c6047464672a59d4cf3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,35 @@ +attrs==25.1.0 +better-profanity==0.7.0 +boto3==1.35.99 +diffusers==0.32.2 +einops==0.8.1 +huggingface-hub==0.29.2 +hydra-core==1.3.2 +imageio[pyav,ffmpeg]==2.37.0 +iopath==0.1.10 +ipdb==0.13.13 +loguru==0.7.2 +mediapy==1.2.2 +megatron-core==0.10.0 +nltk==3.9.1 +numpy==1.26.4 +nvidia-ml-py==12.535.133 +omegaconf==2.3.0 +opencv-python==4.10.0.84 +pandas==2.2.3 +peft==0.14.0 +pillow==11.1.0 +protobuf==4.25.3 +pynvml==12.0.0 +pyyaml==6.0.2 +retinaface-py==0.0.2 +safetensors==0.5.3 +scikit-image==0.24.0 +sentencepiece==0.2.0 +setuptools==76.0.0 +termcolor==2.5.0 +torch==2.6.0 +torchvision==0.21.0 +tqdm==4.66.5 +transformers==4.49.0 +warp-lang==1.7.2 diff --git a/scripts/check_video_links.py b/scripts/check_video_links.py new file mode 100644 index 0000000000000000000000000000000000000000..63fcf13a2fe91c0b2c0fafa23f6017ed8af49576 --- /dev/null +++ b/scripts/check_video_links.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import re + +import requests + + +def find_md_files(root="."): + for dirpath, _, filenames in os.walk(root): + for f in filenames: + if f.endswith(".md"): + yield os.path.join(dirpath, f) + + +def extract_video_urls(md_file): + with open(md_file, "r", encoding="utf-8") as f: + content = f.read() + return re.findall(r' 0: + print(f"Checkpoint {save_path} already exists and is not empty") + return + + pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409") + os.makedirs(pixtral_ckpt_dir, exist_ok=True) + repo_id = "mistralai/Pixtral-12B-2409" + print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...") + snapshot_download( + repo_id=repo_id, + allow_patterns=["params.json", "consolidated.safetensors"], + local_dir=pixtral_ckpt_dir, + local_dir_use_symlinks=False, + ) + orig_dtype = torch.get_default_dtype() + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + + # Load checkpoint file + ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors")) + assert len(ckpt_files) == 1, "ckpt_dir should contain only one file" + ckpt_path = ckpt_files[0] + ckpt = load_file(ckpt_path) + + # Split checkpoint into weights of vision encoder, projector, and LLM + vit_key_prefix = "vision_encoder." + vit_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix): + vit_ckpt[key.lstrip(vit_key_prefix)] = value + + projector_key_prefix = "vision_language_adapter." + projector_ckpt = {} + substring_replacement_map = { + "w_in.": "projector.0.", + "w_out.": "projector.2.", + } + for key, value in ckpt.items(): + if key.startswith(projector_key_prefix): + key = key.lstrip(projector_key_prefix) + for old, new in substring_replacement_map.items(): + key = key.replace(old, new) + projector_ckpt[key] = value + + llm_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix): + continue + llm_ckpt[key] = value + + vlm_ckpt = {} + for key, value in llm_ckpt.items(): + vlm_ckpt["model." + key] = value + for key, value in projector_ckpt.items(): + vlm_ckpt["mm_projector." + key] = value + for key, value in vit_ckpt.items(): + vlm_ckpt["vision_encoder." 
+ key] = value + + # Load config + config_path = os.path.join(pixtral_ckpt_dir, "params.json") + with open(config_path, "r") as f: + pixtral_config = json.load(f) + + # Extract the vision encoder configuration + vision_encoder_config = { + "dim": pixtral_config["vision_encoder"]["hidden_size"], + "num_channels": pixtral_config["vision_encoder"]["num_channels"], + "image_size": pixtral_config["vision_encoder"]["image_size"], + "patch_size": pixtral_config["vision_encoder"]["patch_size"], + "rope_theta": pixtral_config["vision_encoder"]["rope_theta"], + "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"], + "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"], + "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "norm_type": "rmsnorm", + "norm_eps": pixtral_config["norm_eps"], + "image_token_id": pixtral_config["vision_encoder"]["image_token_id"], + } + # Configuration for the 400M ViT of Pixtral 12B VLM + vit_config = dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + # Compare the two configurations + for key, value in vit_config.items(): + assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}" + + llm_config_keys = [ + "dim", + "n_layers", + "head_dim", + "hidden_dim", + "n_heads", + "n_kv_heads", + "rope_theta", + "norm_eps", + "vocab_size", + ] + assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch" + replace_map = { + "hidden_dim": "ffn_hidden_size", + } + llm_config = {} + for k, v in pixtral_config.items(): + if k in llm_config_keys: + llm_config[replace_map.get(k, k)] = v + elif k == "vision_encoder": + llm_config["vision_encoder"] = vit_type + else: + raise ValueError(f"Unknown key: {k}") + + ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt} + torch.save(ckpt_to_save, save_path) + print(f"Model saved to {save_path}") + + # Save config + config_path = os.path.join(save_dir, "config.json") + with open(config_path, "w") as f: + json.dump(llm_config, f) + + torch.set_default_dtype(orig_dtype) # Reset the default dtype + + # Remove the original Pixtral checkpoint + shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True) + print(f"Removed {pixtral_ckpt_dir}") + + +MD5_CHECKSUM_LOOKUP = { + "Cosmos-Predict1-14B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Text2World/model.pt": "c69d1c6e51dc78b959040e8c4035a29b", + "Cosmos-Predict1-14B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Video2World/model.pt": "eaa7aa3678f61d88108c41d7fe201b18", + "Cosmos-Predict1-7B-WorldInterpolator/model.pt": "48a0bdc99d5e41eee05ba8597c4851da", + "Cosmos-Predict1-7B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Text2World/model.pt": "fe9ed68e16cf37b10e7414c9b3ee81e1", + "Cosmos-Predict1-7B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Video2World/model.pt": "ebcdb19c4c4a6a0e1e0bb65e346f6867", + "Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt": 
"f07680ad7eefae57d698778e2a0c7c96", + "Cosmos-Tokenize1-CV8x8x8-720p/image_mean_std.pt": "9f19fd3312fc1198e4905ada02e68bce", + "Cosmos-UpsamplePrompt1-12B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-UpsamplePrompt1-12B-Text2World/model.pt": "52d7a6b8b1ac44d856b4c1ea3f8c8c74", + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview/model.pt": "e3a6ef070deaae0678acd529dc749ea4", + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview/model.pt": "1653f87dce3d558ee01416593552a91c", + "google-t5/t5-11b/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3", + "google-t5/t5-11b/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973", +} + + +def get_md5_checksum(checkpoints_dir, model_name): + print("---------------------") + # Check if there are any expected files for this model + expected_files = [key for key in MD5_CHECKSUM_LOOKUP if key.startswith(model_name + "/")] + if not expected_files: + # No expected files in MD5_CHECKSUM_LOOKUP, check if the directory exists and has content + model_dir = checkpoints_dir / model_name + if not model_dir.exists() or not any(model_dir.iterdir()): + print(f"Directory for {model_name} does not exist or is empty. Download required.") + return False + else: + print(f"Directory for {model_name} exists and contains files. Assuming download is complete.") + return True + # Proceed with checksum verification for models with expected files + for key, value in MD5_CHECKSUM_LOOKUP.items(): + if key.startswith(model_name + "/"): + print(f"Verifying checkpoint {key}...") + file_path = checkpoints_dir.joinpath(key) + # File must exist + if not Path(file_path).exists(): + print(f"Checkpoint {key} does not exist.") + return False + # File must match given MD5 checksum + with open(file_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != value: + print(f"MD5 checksum of checkpoint {key} does not match.") + return False + print(f"Model checkpoints for {model_name} exist with matched MD5 checksums.") + return True + + +def main(args): + ORG_NAME = "nvidia" + + # Mapping from size argument to Hugging Face repository name + model_map = { + "7B": "Cosmos-Predict1-7B", + "14B": "Cosmos-Predict1-14B", + } + + # Additional models that are always downloaded + extra_models = [ + "Cosmos-Tokenize1-CV8x8x8-720p", + "google-t5/t5-11b", + ] + + if "Text2World" in args.model_types: + extra_models.append("Cosmos-UpsamplePrompt1-12B-Text2World") + + # Add interpolator if 7B model is selected + if "7B" in args.model_sizes: + extra_models.append("Cosmos-Predict1-7B-WorldInterpolator") + + # Create local checkpoints folder + checkpoints_dir = Path(args.checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict( + allow_patterns=[ + "README.md", + "model.pt", + "mean_std.pt", + "image_mean_std.pt", + "config.json", + "*.jit", + "guardrail/*", + ] + ) + + # Download the requested diffusion models + for size in args.model_sizes: + for model_type in args.model_types: + suffix = f"-{model_type}" + model_name = model_map[size] + suffix + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs + ) + + # Download the always-included models + for model_name in extra_models: + if 
model_name == "google-t5/t5-11b": + repo_id = model_name + else: + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + # Download all files for Guardrail + snapshot_download( + repo_id=repo_id, + local_dir=str(local_dir), + local_dir_use_symlinks=False, + ) + + if "Video2World" in args.model_types: + # Prompt Upsampler for Cosmos-Predict1-Video2World models + convert_pixtral_checkpoint( + checkpoint_dir=args.checkpoint_dir, + checkpoint_name="Pixtral-12B", + vit_type="pixtral-12b-vit", + ) + + download_guardrail_checkpoints(args.checkpoint_dir) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_diffusion_example_data.py b/scripts/download_diffusion_example_data.py new file mode 100644 index 0000000000000000000000000000000000000000..70bd4ba23a538d8f12dfa1fe7402ab9e33a7cd08 --- /dev/null +++ b/scripts/download_diffusion_example_data.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os + +import ffmpeg +from pytubefix import YouTube + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip +""" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Download example (hdvila) data for posttraining") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--N_videos", type=int, default=128, help="Number of videos to download") + parser.add_argument("--do_download", action="store_true", help="Download the videos") + parser.add_argument("--do_clip", action="store_true", help="Clip the videos") + return parser.parse_args() + + +def convert_time_to_seconds(time_str) -> int: + h, m, s = map(float, time_str.split(":")) + ms = int(time_str.split(".")[-1]) if "." 
in time_str else 0 + return int(h * 3600 + m * 60 + s) + ms / 1000 + + +def download_data(args) -> None: + urls_set = set() + download_count = 0 + + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + os.makedirs(videos_orig_dir, exist_ok=True) + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + + hdvila_jsonl_path = os.path.join(args.dataset_path, "hdvila-100M.jsonl") + with open(hdvila_jsonl_path, "r") as fp: + for line in fp: + json_object = json.loads(line) + url = json_object["url"] + if url not in urls_set: # download videos with unique urls + yt = YouTube(json_object["url"]) + try: + # Download a video + yt.streams.get_highest_resolution().download( + output_path=videos_orig_dir, filename=json_object["video_id"] + ".mp4" + ) + download_count += 1 + urls_set.add(url) + print(f"Downloaded videos: {download_count}/{args.N_videos}") + + # Save metadata - caption and whole metadata + meta_txt_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".txt")) + with open(meta_txt_name, "w") as fp: + fp.write(json_object["caption"]) + meta_json_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".json")) + with open(meta_json_name, "w") as fp: + json.dump(json_object, fp) + except Exception as e: + print(e) + continue + + if len(urls_set) >= args.N_videos: + break + + +def clip_data(args) -> None: + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".json") + ] + videos_orig_list = [ + os.path.join(videos_orig_dir, filename) + for filename in sorted(os.listdir(videos_orig_dir)) + if filename.endswith(".mp4") + ] + + for meta_filename, video_orig_filename in zip(metas_list, videos_orig_list): + with open(meta_filename, "r") as fp: + metadata = json.load(fp) + + # Convert time strings to seconds + start_time = convert_time_to_seconds(metadata["span_start"]) + end_time = convert_time_to_seconds(metadata["span_end"]) + # Clip the video + clip_name = os.path.join(videos_dir, metadata["clip_id"]) + ffmpeg.input(video_orig_filename, ss=start_time, t=end_time - start_time).output(clip_name).run() + + +def main(args) -> None: + if args.do_download: + download_data(args) + if args.do_clip: + clip_data(args) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_gen3c_checkpoints.py b/scripts/download_gen3c_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..114efa8a9fe7880c574e52faeeaa36cf87c6c92f --- /dev/null +++ b/scripts/download_gen3c_checkpoints.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import hashlib +import json +import os +import shutil +from glob import glob +from pathlib import Path + +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import load_file + +from scripts.download_guardrail_checkpoints import download_guardrail_checkpoints + + +def parse_args(): + parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos Predict1 Gen3C models from Hugging Face") + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints." + ) + args = parser.parse_args() + return args + + +def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str): + """ + Main function to convert Pixtral vision model weights to checkpoint and optionally verify and save the converted checkpoint. + + Args: + checkpoint_dir (str): Path to the checkpoint directory + checkpoint_name (str): Name of the checkpoint + vit_type (str): Type of ViT used in the Pixtral model + + This function performs the following steps: + 0. Download the checkpoint from Hugging Face + 1. Loads the original Pixtral checkpoint + 2. Splits the checkpoint into vision encoder, projector, and LLM weights + 3. Reorganizes the weights to match the expected format + 4. Extracts and verifies the vision encoder configuration + 5. Optionally verifies the converted checkpoint by loading it into a VisionTransformer + 6. Optionally saves the converted checkpoint and configuration + """ + + save_dir = os.path.join(checkpoint_dir, checkpoint_name) + os.makedirs(save_dir, exist_ok=True) + # Save the converted checkpoint + save_path = os.path.join(save_dir, "model.pt") + if os.path.exists(save_path) and os.path.getsize(save_path) > 0: + print(f"Checkpoint {save_path} already exists and is not empty") + return + + pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409") + os.makedirs(pixtral_ckpt_dir, exist_ok=True) + repo_id = "mistralai/Pixtral-12B-2409" + print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...") + snapshot_download( + repo_id=repo_id, + allow_patterns=["params.json", "consolidated.safetensors"], + local_dir=pixtral_ckpt_dir, + local_dir_use_symlinks=False, + ) + orig_dtype = torch.get_default_dtype() + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + + # Load checkpoint file + ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors")) + assert len(ckpt_files) == 1, "ckpt_dir should contain only one file" + ckpt_path = ckpt_files[0] + ckpt = load_file(ckpt_path) + + # Split checkpoint into weights of vision encoder, projector, and LLM + vit_key_prefix = "vision_encoder." + vit_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix): + vit_ckpt[key.lstrip(vit_key_prefix)] = value + + projector_key_prefix = "vision_language_adapter." + projector_ckpt = {} + substring_replacement_map = { + "w_in.": "projector.0.", + "w_out.": "projector.2.", + } + for key, value in ckpt.items(): + if key.startswith(projector_key_prefix): + key = key.lstrip(projector_key_prefix) + for old, new in substring_replacement_map.items(): + key = key.replace(old, new) + projector_ckpt[key] = value + + llm_ckpt = {} + for key, value in ckpt.items(): + if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix): + continue + llm_ckpt[key] = value + + vlm_ckpt = {} + for key, value in llm_ckpt.items(): + vlm_ckpt["model." 
+ key] = value + for key, value in projector_ckpt.items(): + vlm_ckpt["mm_projector." + key] = value + for key, value in vit_ckpt.items(): + vlm_ckpt["vision_encoder." + key] = value + + # Load config + config_path = os.path.join(pixtral_ckpt_dir, "params.json") + with open(config_path, "r") as f: + pixtral_config = json.load(f) + + # Extract the vision encoder configuration + vision_encoder_config = { + "dim": pixtral_config["vision_encoder"]["hidden_size"], + "num_channels": pixtral_config["vision_encoder"]["num_channels"], + "image_size": pixtral_config["vision_encoder"]["image_size"], + "patch_size": pixtral_config["vision_encoder"]["patch_size"], + "rope_theta": pixtral_config["vision_encoder"]["rope_theta"], + "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"], + "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"], + "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"], + "norm_type": "rmsnorm", + "norm_eps": pixtral_config["norm_eps"], + "image_token_id": pixtral_config["vision_encoder"]["image_token_id"], + } + # Configuration for the 400M ViT of Pixtral 12B VLM + vit_config = dict( + dim=1024, + num_channels=3, + image_size=1024, + patch_size=16, + rope_theta=10000, + ffn_hidden_size=4096, + n_layers=24, + n_heads=16, + n_kv_heads=16, + norm_type="rmsnorm", + norm_eps=1e-5, + image_token_id=10, + ) + # Compare the two configurations + for key, value in vit_config.items(): + assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}" + + llm_config_keys = [ + "dim", + "n_layers", + "head_dim", + "hidden_dim", + "n_heads", + "n_kv_heads", + "rope_theta", + "norm_eps", + "vocab_size", + ] + assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch" + replace_map = { + "hidden_dim": "ffn_hidden_size", + } + llm_config = {} + for k, v in pixtral_config.items(): + if k in llm_config_keys: + llm_config[replace_map.get(k, k)] = v + elif k == "vision_encoder": + llm_config["vision_encoder"] = vit_type + else: + raise ValueError(f"Unknown key: {k}") + + ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt} + torch.save(ckpt_to_save, save_path) + print(f"Model saved to {save_path}") + + # Save config + config_path = os.path.join(save_dir, "config.json") + with open(config_path, "w") as f: + json.dump(llm_config, f) + + torch.set_default_dtype(orig_dtype) # Reset the default dtype + + # Remove the original Pixtral checkpoint + shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True) + print(f"Removed {pixtral_ckpt_dir}") + + +MD5_CHECKSUM_LOOKUP = { + "Cosmos-Predict1-14B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Text2World/model.pt": "c69d1c6e51dc78b959040e8c4035a29b", + "Cosmos-Predict1-14B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-14B-Video2World/model.pt": "eaa7aa3678f61d88108c41d7fe201b18", + "Cosmos-Predict1-7B-WorldInterpolator/model.pt": "48a0bdc99d5e41eee05ba8597c4851da", + "Cosmos-Predict1-7B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Text2World/model.pt": "fe9ed68e16cf37b10e7414c9b3ee81e1", + "Cosmos-Predict1-7B-Video2World/guardrail/video_content_safety_filter/safety_filter.pt": 
"b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-Predict1-7B-Video2World/model.pt": "ebcdb19c4c4a6a0e1e0bb65e346f6867", + "Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt": "f07680ad7eefae57d698778e2a0c7c96", + "Cosmos-Tokenize1-CV8x8x8-720p/image_mean_std.pt": "9f19fd3312fc1198e4905ada02e68bce", + "Cosmos-UpsamplePrompt1-12B-Text2World/guardrail/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf", + "Cosmos-UpsamplePrompt1-12B-Text2World/model.pt": "52d7a6b8b1ac44d856b4c1ea3f8c8c74", + "Cosmos-Predict1-7B-Text2World-Sample-AV-Multiview/model.pt": "e3a6ef070deaae0678acd529dc749ea4", + "Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview/model.pt": "1653f87dce3d558ee01416593552a91c", + "Gen3C-Cosmos-7B/model.pt": "38644bf823aa5272acef60cfad8bc0f7", + "google-t5/t5-11b/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3", + "google-t5/t5-11b/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973", +} + + +def get_md5_checksum(checkpoints_dir, model_name): + print("---------------------") + # Check if there are any expected files for this model + expected_files = [key for key in MD5_CHECKSUM_LOOKUP if key.startswith(model_name + "/")] + if not expected_files: + # No expected files in MD5_CHECKSUM_LOOKUP, check if the directory exists and has content + model_dir = checkpoints_dir / model_name + if not model_dir.exists() or not any(model_dir.iterdir()): + print(f"Directory for {model_name} does not exist or is empty. Download required.") + return False + else: + print(f"Directory for {model_name} exists and contains files. Assuming download is complete.") + return True + # Proceed with checksum verification for models with expected files + for key, value in MD5_CHECKSUM_LOOKUP.items(): + if key.startswith(model_name + "/"): + print(f"Verifying checkpoint {key}...") + file_path = checkpoints_dir.joinpath(key) + # File must exist + if not Path(file_path).exists(): + print(f"Checkpoint {key} does not exist.") + return False + # File must match given MD5 checksum + with open(file_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != value: + print(f"MD5 checksum of checkpoint {key} does not match.") + return False + print(f"Model checkpoints for {model_name} exist with matched MD5 checksums.") + return True + + +def main(args): + ORG_NAME = "nvidia" + + # Additional models that are always downloaded + extra_models = [ + "Cosmos-Tokenize1-CV8x8x8-720p", + "google-t5/t5-11b", + ] + + # Create local checkpoints folder + checkpoints_dir = Path(args.checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict( + allow_patterns=[ + "README.md", + "model.pt", + "mean_std.pt", + "image_mean_std.pt", + "config.json", + "*.jit", + "guardrail/*", + ] + ) + + # Download the requested diffusion models + model_name = "Gen3C-Cosmos-7B" + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs + ) + + # Download the always-included models + for model_name in extra_models: + if model_name == "google-t5/t5-11b": + repo_id = model_name + else: + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + 
print(f"Downloading {repo_id} to {local_dir}...") + # Download all files for Guardrail + snapshot_download( + repo_id=repo_id, + local_dir=str(local_dir), + local_dir_use_symlinks=False, + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_guardrail_checkpoints.py b/scripts/download_guardrail_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..81c4dafc13033b5ca8ce71b98ac28f51da510d0a --- /dev/null +++ b/scripts/download_guardrail_checkpoints.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List + +from huggingface_hub import snapshot_download + + +def download_models(models: List[str], destination_root: str): + """ + Download models from Hugging Face Hub and save them in org/project structure. + + Args: + models: List of model IDs in format 'org/project' + destination_root: Root directory where models will be saved + """ + for model_id in models: + model_id, revision = model_id.split(":") if ":" in model_id else (model_id, None) + print(f"Downloading {model_id}...") + + # Create the full path for the model + model_path = os.path.join(destination_root, model_id) + + try: + # Download the model + snapshot_download( + repo_id=model_id, + local_dir=model_path, + revision=revision, + ) + print(f"Successfully downloaded {model_id} to {model_path}") + + except Exception as e: + raise RuntimeError(f"Error downloading {model_id}: {str(e)}. Please delete the directory and try again.") + + +def download_guardrail_checkpoints(destination_root: str): + """ + Download guardrail checkpoints from Hugging Face Hub and save them in org/project structure. + + Args: + destination_root: Root directory where checkpoints will be saved + """ + # List of models to download + models_to_download = [ + "meta-llama/Llama-Guard-3-8B", + "nvidia/Cosmos-Guardrail1", + ] + + # Create the destination directory if it doesn't exist + os.makedirs(destination_root, exist_ok=True) + + # Download the models + download_models(models_to_download, destination_root) + + +if __name__ == "__main__": + download_guardrail_checkpoints("checkpoints") diff --git a/scripts/download_tokenizer_checkpoints.py b/scripts/download_tokenizer_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..249ee29cdad2ca648b691b74ff2ddb8b18c926b3 --- /dev/null +++ b/scripts/download_tokenizer_checkpoints.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import hashlib +import os +from pathlib import Path + +from huggingface_hub import snapshot_download + +from scripts.download_guardrail_checkpoints import download_guardrail_checkpoints + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="A script to download NVIDIA Cosmos-Tokenizer1 models from Hugging Face" + ) + parser.add_argument( + "--tokenizer_types", + nargs="*", + default=[ + "CV8x8x8-720p", + "DV8x16x16-720p", + "CI8x8-360p", + "CI16x16-360p", + "CV4x8x8-360p", + "DI8x8-360p", + "DI16x16-360p", + "DV4x8x8-360p", + ], # Download all by default + choices=[ + "CV8x8x8-720p", + "DV8x16x16-720p", + "CI8x8-360p", + "CI16x16-360p", + "CV4x8x8-360p", + "DI8x8-360p", + "DI16x16-360p", + "DV4x8x8-360p", + ], + help="Which tokenizer model types to download. Possible values: CV8x8x8-720p, DV8x16x16-720p, CV4x8x8-360p, DV4x8x8-360p", + ) + parser.add_argument( + "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints." + ) + args = parser.parse_args() + return args + + +MD5_CHECKSUM_LOOKUP = { + "Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d", + "Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8", + "Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299", + "Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt": "f07680ad7eefae57d698778e2a0c7c96", + "Cosmos-Tokenize1-CI16x16-360p/autoencoder.jit": "98f8fdf2ada5537705d6d1bc22c63cf1", + "Cosmos-Tokenize1-CI16x16-360p/decoder.jit": "dd31a73a8c7062bab25492401d83b473", + "Cosmos-Tokenize1-CI16x16-360p/encoder.jit": "7be1dadea5a1c283996ca1ce5b1a95a9", + "Cosmos-Tokenize1-CI8x8-360p/autoencoder.jit": "b2ff9280b12a97202641bb2a41d7b271", + "Cosmos-Tokenize1-CI8x8-360p/decoder.jit": "57fb213cd88c0a991e9d400875164571", + "Cosmos-Tokenize1-CI8x8-360p/encoder.jit": "138fe257df41d7a43c17396c23086565", + "Cosmos-Tokenize1-CV4x8x8-360p/autoencoder.jit": "0690ff725700128424d082b44a1eda08", + "Cosmos-Tokenize1-CV4x8x8-360p/decoder.jit": "7573744ec14cb1b2abdf9c80318b7224", + "Cosmos-Tokenize1-CV4x8x8-360p/encoder.jit": "fe3a7193defcb2db0b849b6df480b5e6", + "Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d", + "Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8", + "Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299", + "Cosmos-Tokenize1-DI16x16-360p/autoencoder.jit": "88195130b86c3434d3d4b0e0376def6b", + "Cosmos-Tokenize1-DI16x16-360p/decoder.jit": "bf27a567388902acbd8abcc3a5afd8dd", + "Cosmos-Tokenize1-DI16x16-360p/encoder.jit": "12bae3a56c79a7ca0beb774843ee8c58", + "Cosmos-Tokenize1-DI8x8-360p/autoencoder.jit": "1d638e6034fcd43619bc1cdb343ebe56", + "Cosmos-Tokenize1-DI8x8-360p/decoder.jit": "b9b5eccaa7ab9ffbccae3b05b3903311", + "Cosmos-Tokenize1-DI8x8-360p/encoder.jit": "2bfa3c189aacdf9dc8faf17bcc30dd82", + "Cosmos-Tokenize1-DV4x8x8-360p/autoencoder.jit": "ff8802dc4497be60dc24a8f692833eed", + "Cosmos-Tokenize1-DV4x8x8-360p/decoder.jit": "f9a7d4bd24e4d2ee210cfd5f21550ce8", + 
"Cosmos-Tokenize1-DV4x8x8-360p/encoder.jit": "7af30a0223b2984d9d27dd3054fcd7af", + "Cosmos-Tokenize1-DV8x16x16-720p/autoencoder.jit": "606b8585b637f06057725cbb67036ae6", + "Cosmos-Tokenize1-DV8x16x16-720p/decoder.jit": "f0c8a9d992614a43e7ce24ebfc901e26", + "Cosmos-Tokenize1-DV8x16x16-720p/encoder.jit": "95186b0410346a3f0cf250b76daec452", +} + + +def get_md5_checksum(checkpoints_dir, model_name): + print("---------------------") + for key, value in MD5_CHECKSUM_LOOKUP.items(): + if key.startswith(model_name): + print(f"Verifying checkpoint {key}...") + file_path = checkpoints_dir.joinpath(key) + # File must exist + if not Path(file_path).exists(): + print(f"Checkpoint {key} does not exist.") + return False + # File must match give MD5 checksum + with open(file_path, "rb") as f: + file_md5 = hashlib.md5(f.read()).hexdigest() + if file_md5 != value: + print(f"MD5 checksum of checkpoint {key} does not match.") + return False + print(f"Model checkpoints for {model_name} exist with matched MD5 checksums.") + return True + + +def main(args) -> None: + ORG_NAME = "nvidia" + + # Mapping from size argument to Hugging Face repository name + model_map = { + "CV8x8x8-720p": "Cosmos-Tokenize1-CV8x8x8-720p", + "DV8x16x16-720p": "Cosmos-Tokenize1-DV8x16x16-720p", + "CI8x8-360p": "Cosmos-Tokenize1-CI8x8-360p", + "CI16x16-360p": "Cosmos-Tokenize1-CI16x16-360p", + "CV4x8x8-360p": "Cosmos-Tokenize1-CV4x8x8-360p", + "DI8x8-360p": "Cosmos-Tokenize1-DI8x8-360p", + "DI16x16-360p": "Cosmos-Tokenize1-DI16x16-360p", + "DV4x8x8-360p": "Cosmos-Tokenize1-DV4x8x8-360p", + } + + # Create local checkpoints folder + checkpoints_dir = Path(args.checkpoint_dir) + checkpoints_dir.mkdir(parents=True, exist_ok=True) + + download_kwargs = dict(allow_patterns=["README.md", "model.pt", "mean_std.pt", "config.json", "*.jit"]) + + # Download the requested Tokenizer models + for tokenizer_type in args.tokenizer_types: + model_name = model_map[tokenizer_type] + repo_id = f"{ORG_NAME}/{model_name}" + local_dir = checkpoints_dir.joinpath(model_name) + + if not get_md5_checksum(checkpoints_dir, model_name): + local_dir.mkdir(parents=True, exist_ok=True) + print(f"Downloading {repo_id} to {local_dir}...") + snapshot_download( + repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs + ) + + download_guardrail_checkpoints(args.checkpoint_dir) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/download_tokenizer_example_data.py b/scripts/download_tokenizer_example_data.py new file mode 100644 index 0000000000000000000000000000000000000000..28b7f2257f3f272cc29c6163ddcc8cd7e46ef2d0 --- /dev/null +++ b/scripts/download_tokenizer_example_data.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json +import os + +import ffmpeg +from pytubefix import YouTube + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_tokenizer_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip +""" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Download example (hdvila) data for posttraining") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--N_videos", type=int, default=128, help="Number of videos to download") + parser.add_argument("--do_download", action="store_true", help="Download the videos") + parser.add_argument("--do_clip", action="store_true", help="Clip the videos") + return parser.parse_args() + + +def convert_time_to_seconds(time_str) -> int: + h, m, s = map(float, time_str.split(":")) + ms = int(time_str.split(".")[-1]) if "." in time_str else 0 + return int(h * 3600 + m * 60 + s) + ms / 1000 + + +def download_data(args) -> None: + urls_set = set() + download_count = 0 + + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + os.makedirs(videos_orig_dir, exist_ok=True) + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + + hdvila_jsonl_path = os.path.join(args.dataset_path, "hdvila-100M.jsonl") + with open(hdvila_jsonl_path, "r") as fp: + for line in fp: + json_object = json.loads(line) + url = json_object["url"] + if url not in urls_set: # download videos with unique urls + yt = YouTube(json_object["url"]) + try: + # Download a video + yt.streams.get_highest_resolution().download( + output_path=videos_orig_dir, filename=json_object["video_id"] + ".mp4" + ) + download_count += 1 + urls_set.add(url) + print(f"Downloaded videos: {download_count}/{args.N_videos}") + + # Save metadata - caption and whole metadata + meta_txt_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".txt")) + with open(meta_txt_name, "w") as fp: + fp.write(json_object["caption"]) + meta_json_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".json")) + with open(meta_json_name, "w") as fp: + json.dump(json_object, fp) + except Exception as e: + print(e) + continue + + if len(urls_set) >= args.N_videos: + break + + +def clip_data(args) -> None: + videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") + videos_dir = os.path.join(args.dataset_path, "videos") + os.makedirs(videos_dir, exist_ok=True) + metas_dir = os.path.join(args.dataset_path, "metas") + + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".json") + ] + videos_orig_list = [ + os.path.join(videos_orig_dir, filename) + for filename in sorted(os.listdir(videos_orig_dir)) + if filename.endswith(".mp4") + ] + + for meta_filename, video_orig_filename in zip(metas_list, videos_orig_list): + with open(meta_filename, "r") as fp: + metadata = json.load(fp) + + # Convert time strings to seconds + start_time = convert_time_to_seconds(metadata["span_start"]) + end_time = convert_time_to_seconds(metadata["span_end"]) + # Clip the video + clip_name = os.path.join(videos_dir, metadata["clip_id"]) + ffmpeg.input(video_orig_filename, ss=start_time, t=end_time - start_time).output(clip_name).run() + + +def main(args) -> None: + if args.do_download: + download_data(args) + if 
args.do_clip: + clip_data(args) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100644 index 0000000000000000000000000000000000000000..42c3e2bd41407ab284b14bf2cb0bc76d67785374 --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cosmos_root=$(git rev-parse --show-toplevel) +venv_folder=$cosmos_root/.venv +scripts_folder=$cosmos_root/scripts + +echo "Formatting $cosmos_root" +if [ ! -d "$scripts_folder" ]; then + echo "script has to be called from repo root dir!" + exit -1 +fi + +if [ ! -d "$venv_folder" ]; then + mkdir -p $venv_folder + python3 -m pip install virtualenv + python3 -m venv $venv_folder +fi + +source $venv_folder/bin/activate + +dependencies=($(pip freeze | grep -E 'pre-commit==3.7.1|flake8==7.1.0|black==24.4.2|isort==5.13.2|loguru|termcolor')) +if [ "${#dependencies[@]}" -ne 6 ]; then + python3 -m pip install --upgrade pip + python3 -m pip install pre-commit==3.7.1 + python3 -m pip install flake8==7.1.0 + python3 -m pip install black==24.4.2 + python3 -m pip install isort==5.13.2 + python3 -m pip install loguru + python3 -m pip install termcolor +fi +set -e +python3 $scripts_folder/ip_header.py +pre-commit install-hooks +pre-commit run --all diff --git a/scripts/get_t5_embeddings.py b/scripts/get_t5_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..53b6ebe889340e9360a946fa3022766dcf027c36 --- /dev/null +++ b/scripts/get_t5_embeddings.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. 
+ input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + metas_dir = os.path.join(args.dataset_path, "metas") + metas_list = [ + os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".txt") + ] + + t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") + os.makedirs(t5_xxl_dir, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + for meta_filename in metas_list: + t5_xxl_filename = os.path.join(t5_xxl_dir, os.path.basename(meta_filename).replace(".txt", ".pickle")) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(meta_filename, "r") as fp: + prompt = fp.read().strip() + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/get_t5_embeddings_from_bridge.py b/scripts/get_t5_embeddings_from_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcdfe461914b86e61dbbac9c060f41938c7ed60 --- /dev/null +++ b/scripts/get_t5_embeddings_from_bridge.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
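Each output of get_t5_embeddings.py is a pickled list with one float16 array per encoded prompt, trimmed to the prompt's token count. A minimal sketch of how one of these files can be inspected afterwards (not part of the diff; the clip name is a placeholder):

import pickle

import numpy as np

with open("datasets/hdvila/t5_xxl/example_clip.pickle", "rb") as fp:
    encoded_text = pickle.load(fp)

emb = encoded_text[0]  # a single prompt was encoded, so the list has one entry
assert emb.dtype == np.float16
print(emb.shape)  # (num_tokens, hidden_size), with padding positions already trimmed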
+ +import argparse +import json +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings_from_bridge.py --dataset_path datasets/bridge +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument("--dataset_path", type=str, default="datasets/bridge", help="Root path to the dataset") + parser.add_argument( + "--subset", + type=str, + default="train", + choices=("train", "val", "test"), + help="Subset of the bridge dataset to process", + ) + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. 
+ input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + annotation_dir = os.path.join(args.dataset_path, "annotation", args.subset) + annotation_list = [ + os.path.join(annotation_dir, filename) + for filename in sorted(os.listdir(annotation_dir)) + if filename.endswith(".json") + ] + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + for annotation_filename in annotation_list: + # Save T5 embeddings as pickle file + t5_xxl_filename = os.path.join( + annotation_dir, os.path.basename(annotation_filename).replace(".json", ".pickle") + ) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(annotation_filename, "r") as fp: + metadata = json.load(fp) + prompt = metadata["texts"][0] + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e8783c5058c8daba8f3ee6a1d93c321f532f25 --- /dev/null +++ b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
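The bridge variant above reads only the "texts" field of each annotation JSON and writes the embedding pickle next to it. A minimal sketch of an annotation file in that shape (assumed for illustration; the episode name and caption are placeholders, not taken from the Bridge release):

import json
import os

annotation_dir = "datasets/bridge/annotation/train"
os.makedirs(annotation_dir, exist_ok=True)
annotation = {"texts": ["Pick up the spoon and place it in the sink."]}  # texts[0] becomes the T5 prompt
with open(os.path.join(annotation_dir, "episode_000.json"), "w") as fp:
    json.dump(annotation, fp)
# Running get_t5_embeddings_from_bridge.py afterwards writes episode_000.pickle into the same folder.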
+ +import argparse +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings_from_cosmos_nemo_assets.py --dataset_path datasets/cosmos_nemo_assets +""" + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument( + "--dataset_path", type=str, default="datasets/cosmos_nemo_assets", help="Root path to the dataset" + ) + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--prompt", type=str, default="A video of sks teal robot.", help="Text prompt for the dataset") + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. + input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + videos_dir = os.path.join(args.dataset_path, "videos") + + # Cosmos-NeMo-Assets come with videos only. A prompt is provided as an argument. + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + metas_list = [ + os.path.join(metas_dir, filename.replace(".mp4", ".txt")) + for filename in sorted(os.listdir(videos_dir)) + if filename.endswith(".mp4") + ] + + # Write txt files to match other dataset formats. 
+ for meta_filename in metas_list: + if not os.path.exists(meta_filename): + with open(meta_filename, "w") as fp: + fp.write(args.prompt) + + t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") + os.makedirs(t5_xxl_dir, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + for meta_filename in metas_list: + t5_xxl_filename = os.path.join(t5_xxl_dir, os.path.basename(meta_filename).replace(".txt", ".pickle")) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(meta_filename, "r") as fp: + prompt = fp.read().strip() + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/get_t5_embeddings_from_waymo.py b/scripts/get_t5_embeddings_from_waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..a33d80ca1d01b734daff570c12cafa72f610ad4d --- /dev/null +++ b/scripts/get_t5_embeddings_from_waymo.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pickle +from typing import Tuple + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +"""example command +CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings_from_waymo.py --dataset_path datasets/waymo +""" + +PREFIX_PROMPTS = { + "pinhole_front": "The video is captured from a camera mounted on a car. The camera is facing forward.", + "pinhole_front_left": "The video is captured from a camera mounted on a car. The camera is facing forward and slightly to the left.", + "pinhole_front_right": "The video is captured from a camera mounted on a car. The camera is facing forward and slightly to the right.", + "pinhole_side_left": "The video is captured from a camera mounted on a car. The camera is facing to the left.", + "pinhole_side_right": "The video is captured from a camera mounted on a car. 
The camera is facing to the right.", +} + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") + parser.add_argument("--dataset_path", type=str, default="datasets/waymo", help="Root path to the dataset") + parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument( + "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" + ) + parser.add_argument("--prompt", type=str, default="A video of sks teal robot.", help="Text prompt for the dataset") + parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") + return parser.parse_args() + + +def init_t5( + pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" +) -> Tuple[T5TokenizerFast, T5EncoderModel]: + """Initialize and return the T5 tokenizer and text encoder.""" + tokenizer = T5TokenizerFast.from_pretrained( + pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir + ) + text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) + text_encoder.to("cuda") + text_encoder.eval() + return tokenizer, text_encoder + + +@torch.inference_mode() +def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: + """ + Encode a batch of text prompts to a batch of T5 embeddings. + Parameters: + tokenizer: T5 embedding tokenizer. + encoder: T5 embedding text encoder. + prompts: A batch of text prompts. + max_length: Sequence length of text embedding (defaults to 512). + """ + + batch_encoding = tokenizer.batch_encode_plus( + prompts, + return_tensors="pt", + truncation=True, + padding="max_length", + max_length=max_length, + return_length=True, + return_offsets_mapping=False, + ) + + # We expect all the processing is done on GPU. + input_ids = batch_encoding.input_ids.cuda() + attn_mask = batch_encoding.attention_mask.cuda() + + outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = outputs.last_hidden_state + + lengths = attn_mask.sum(dim=1).cpu() + for batch_id in range(encoded_text.shape[0]): + encoded_text[batch_id][lengths[batch_id] :] = 0 + + encoded_text = encoded_text.cpu().numpy().astype(np.float16) + encoded_text = encoded_text[:, :max_length] + + # trim zeros to save space + encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] + + return encoded_text + + +def main(args) -> None: + videos_dir = os.path.join(args.dataset_path, "videos") + + metas_dir = os.path.join(args.dataset_path, "metas") + os.makedirs(metas_dir, exist_ok=True) + metas_list = [ + os.path.join(metas_dir, viewname, filename.replace(".mp4", ".txt")) + for viewname in sorted(os.listdir(videos_dir)) + for filename in sorted(os.listdir(videos_dir + "/" + viewname)) + if filename.endswith(".mp4") + ] + + # Write txt files to match other dataset formats. 
+ for meta_filename in metas_list: + if not os.path.exists(meta_filename): + with open(meta_filename, "w") as fp: + fp.write(args.prompt) + + t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") + os.makedirs(t5_xxl_dir, exist_ok=True) + + # Initialize T5 + tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) + + # Extract T5 embeddings for prefix prompt + for view_name, prefix_prompt in PREFIX_PROMPTS.items(): + t5_xxl_filename = os.path.join(args.dataset_path, "cache", f"prefix_t5_embeddings_{view_name}.pickle") + os.makedirs(os.path.dirname(t5_xxl_filename), exist_ok=True) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prefix_prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + for meta_filename in metas_list: + t5_xxl_filename = os.path.join( + t5_xxl_dir, meta_filename.split("/")[-2], os.path.basename(meta_filename).replace(".txt", ".pickle") + ) + os.makedirs(os.path.dirname(t5_xxl_filename), exist_ok=True) + if os.path.exists(t5_xxl_filename): + # Skip if the file already exists + continue + + with open(meta_filename, "r") as fp: + prompt = fp.read().strip() + + # Compute T5 embeddings + encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) + + # Save T5 embeddings as pickle file + with open(t5_xxl_filename, "wb") as fp: + pickle.dump(encoded_text, fp) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/ip_header.py b/scripts/ip_header.py new file mode 100644 index 0000000000000000000000000000000000000000..f139c36ed77da543f9006fe2adecd080686f118c --- /dev/null +++ b/scripts/ip_header.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys + +import termcolor + +parser = argparse.ArgumentParser(description="Cosmos IP header checker/fixer") +parser.add_argument("--fix", action="store_true", help="apply the fixes instead of checking") +args, files_to_check = parser.parse_known_args() + + +def get_header(ext: str = "py", old: str | bool = False) -> list[str]: + header = [ + "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.", + "SPDX-License-Identifier: Apache-2.0", + "", + 'Licensed under the Apache License, Version 2.0 (the "License");', + "you may not use this file except in compliance with the License.", + "You may obtain a copy of the License at", + "", + "http://www.apache.org/licenses/LICENSE-2.0", + "", + "Unless required by applicable law or agreed to in writing, software", + 'distributed under the License is distributed on an "AS IS" BASIS,', + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "See the License for the specific language governing permissions and", + "limitations under the License.", + ] + if ext == ".py" and old: + if old == "single": + header = ["'''"] + header + ["'''"] + elif old == "double": + header = ['"""'] + header + ['"""'] + else: + raise NotImplementedError + elif ext in (".py", ".yaml"): + header = [("# " + line if line else "#") for line in header] + elif ext in (".c", ".cpp", ".cu", ".h", ".cuh"): + header = ["/*"] + [(" * " + line if line else " *") for line in header] + [" */"] + else: + raise NotImplementedError + return header + + +def apply_file(file: str, results: dict[str, int], fix: bool = False) -> None: + if file.endswith("__init__.py"): + return + ext = os.path.splitext(file)[1] + content = open(file).read().splitlines() + header = get_header(ext=ext) + if fix: + if _check_header(content, header): + return + print(f"fixing: {file}") + while len(content) > 0 and not content[0]: + content.pop(0) + content = header + [""] + content + with open(file, "w") as file_obj: + for line in content: + file_obj.write(line + "\n") + else: + if not _check_header(content, header): + bad_header = colorize("BAD HEADER", color="red", bold=True) + print(f"{bad_header}: {file}") + results[file] = 1 + else: + results[file] = 0 + + +def traverse_directory(path: str, results: dict[str, int], fix: bool = False, substrings_to_skip=[]) -> None: + files = os.listdir(path) + for file in files: + full_path = os.path.join(path, file) + if os.path.isdir(full_path): + traverse_directory(full_path, results, fix=fix, substrings_to_skip=substrings_to_skip) + elif os.path.isfile(full_path): + ext = os.path.splitext(file)[1] + to_skip = any(substr in full_path for substr in substrings_to_skip) + if not to_skip and ext in (".py", ".yaml", ".c", ".cpp", ".cu", ".h", ".cuh"): + apply_file(full_path, results, fix=fix) + else: + raise NotImplementedError + + +def _check_header(content: list[str], header: list[str]) -> bool: + if content[: len(header)] != header: + return False + + i = len(header) + blank_line_count = 0 + + while i < len(content) and content[i].strip() == "": + blank_line_count += 1 + i += 1 + + # Allow at most two blank lines + if blank_line_count > 2: + return False + + # Must have at least one non-empty line after the blank lines + return i < len(content) + + +def colorize(x: str, color: str, bold: bool = False) -> str: + return termcolor.colored(str(x), color=color, attrs=("bold",) if bold else None) # type: ignore + + +if __name__ == "__main__": + if not files_to_check: + files_to_check = [ + "cosmos_predict1/auxiliary", + "cosmos_predict1/diffusion", + "cosmos_predict1/callbacks", + "cosmos_predict1/checkpointer", + "cosmos_predict1/autoregressive", + "cosmos_predict1/tokenizer", + "cosmos_predict1/utils", + ] + + for file in files_to_check: + assert os.path.isfile(file) or os.path.isdir(file), f"{file} is neither a directory or a file!" 
+ + substrings_to_skip = ["prompt_upsampler"] + results = dict() + for file in files_to_check: + if os.path.isfile(file): + apply_file(file, results, fix=args.fix) + elif os.path.isdir(file): + traverse_directory(file, results, fix=args.fix, substrings_to_skip=substrings_to_skip) + else: + raise NotImplementedError + + if any(results.values()): + sys.exit(1) diff --git a/scripts/merge_autoregressive_tp_checkpoints.py b/scripts/merge_autoregressive_tp_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..07cd6b4866c5160dc2a291cb031e4ec99e4abeb5 --- /dev/null +++ b/scripts/merge_autoregressive_tp_checkpoints.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model_config +from cosmos_predict1.autoregressive.utils.checkpoint import merge_tensor_parallel_state_dicts +from cosmos_predict1.utils import log + + +def merge_sharded_checkpoints(checkpoint_path, output_path, tensor_parallel_size, model_size, model_family): + assert checkpoint_path.endswith(".pt"), "Checkpoint path must end with .pt" + assert model_family == "cosmos", "Only cosmos model family is currently supported" + assert model_size == "4b", "Only 4B model size is currently supported" + model_config, _ = create_video2world_model_config( + model_ckpt_path=checkpoint_path, + model_family=model_family, + model_size=model_size, + tensor_model_parallel_size=tensor_parallel_size, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + ) + log.info(f"Merging sharded checkpoints from {checkpoint_path.replace('.pt', '_model_mp_*.pt')} into {output_path}") + + checkpoint_paths = [checkpoint_path.replace(".pt", f"_model_mp_{rank}.pt") for rank in range(tensor_parallel_size)] + for path in checkpoint_paths: + assert os.path.exists(path), f"Checkpoint path {path} does not exist" + log.info(f"Found checkpoint {path}") + sharded_state_dicts = [torch.load(path, map_location="cpu") for path in checkpoint_paths] + merged_state_dict = merge_tensor_parallel_state_dicts(sharded_state_dicts, model_config) + torch.save(merged_state_dict, output_path) + log.info(f"Merged checkpoint saved to {output_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Merge Cosmos-Predict1-4B autoregressive checkpoints") + parser.add_argument( + "--checkpoint_path", + "-c", + type=str, + required=True, + help="Path to the checkpoint to merge. 
Must end with .pt and be colocated with the sharded checkpoints ending in _model_mp_{rank}.pt", + ) + parser.add_argument("--output_path", "-o", type=str, required=True, help="Path to the output merged checkpoint") + parser.add_argument("--tensor_parallel_size", "-t", type=int, required=True, help="Tensor parallel size") + parser.add_argument("--model_size", "-s", type=str, required=True, help="Model size") + parser.add_argument("--model_family", "-f", type=str, required=False, default="cosmos", help="Model family") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + merge_sharded_checkpoints( + args.checkpoint_path, args.output_path, args.tensor_parallel_size, args.model_size, args.model_family + ) diff --git a/scripts/shard_autoregressive_base_checkpoints.py b/scripts/shard_autoregressive_base_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..4e77144cd013c63cfc1993c0fda8963c6857e092 --- /dev/null +++ b/scripts/shard_autoregressive_base_checkpoints.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch + +from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model_config +from cosmos_predict1.autoregressive.utils.checkpoint import obtain_tensor_parallel_state_dict +from cosmos_predict1.utils import log + + +def shard_checkpoint(checkpoint_path, tensor_parallel_size, model_size, model_family, target_backend="pytorch"): + assert checkpoint_path.endswith(".pt"), "Checkpoint path must end with .pt" + assert model_family == "cosmos", "Only cosmos model family is currently supported" + assert model_size == "4b", "Only 4B model size is currently supported" + model_config, _ = create_video2world_model_config( + model_ckpt_path=checkpoint_path, + model_family=model_family, + model_size=model_size, + tensor_model_parallel_size=tensor_parallel_size, + tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit", + ) + log.info(f"Sharding checkpoint {checkpoint_path} with {tensor_parallel_size} ranks") + checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True) + for tensor_parallel_rank in range(tensor_parallel_size): + shard = obtain_tensor_parallel_state_dict( + checkpoint, tensor_parallel_size, tensor_parallel_rank, model_config, target_backend=target_backend + ) + shard_path = checkpoint_path.replace(".pt", f"_model_mp_{tensor_parallel_rank}.pt") + log.info(f"Saving shard {shard_path}") + torch.save(shard, shard_path) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Shard NVIDIA Cosmos Predict1 autoregressive models") + parser.add_argument( + "--checkpoint_path", + "-c", + type=str, + required=True, + default="checkpoints/Cosmos-Predict1-4B/model.pt", + help="Path to the checkpoint to shard", + ) + parser.add_argument("--tensor_parallel_size", "-t", type=int, 
required=True, help="Number of tensor parallel ranks") + parser.add_argument("--target_backend", "-b", type=str, required=False, default="pytorch", help="Target backend") + parser.add_argument("--model_size", "-s", type=str, required=True, help="Model size") + parser.add_argument("--model_family", "-f", type=str, required=False, default="cosmos", help="Model family") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + shard_checkpoint( + args.checkpoint_path, args.tensor_parallel_size, args.model_size, args.model_family, args.target_backend + ) diff --git a/scripts/test_environment.py b/scripts/test_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..606447884e3701b8992a6444f9c42b7aef9c400e --- /dev/null +++ b/scripts/test_environment.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import os +import sys + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--training", + action="store_true", + help="Whether to check training-specific dependencies", + ) + return parser.parse_args() + + +def check_packages(package_list): + global all_success + for package in package_list: + try: + _ = importlib.import_module(package) + except Exception as e: + print(f"\033[91m[ERROR]\033[0m Package not successfully imported: \033[93m{package}\033[0m ({e})") + all_success = False + else: + print(f"\033[92m[SUCCESS]\033[0m {package} found") + + +args = parse_args() + +if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): + detected = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. You have: \033[93m{detected}\033[0m") + sys.exit(1) + +if "CONDA_PREFIX" not in os.environ: + print("\033[93m[WARNING]\033[0m Cosmos should be run under a conda environment.") + +print("Attempting to import critical packages...") + +packages = [ + "torch", + "torchvision", + "diffusers", + "transformers", + "megatron.core", + "transformer_engine", +] +packages_training = [ + "apex.multi_tensor_apply", +] +all_success = True + +check_packages(packages) +if args.training: + check_packages(packages_training) + +if all_success: + print("-----------------------------------------------------------") + print("\033[92m[SUCCESS]\033[0m Cosmos environment setup is successful!")