roll-ai committed
Commit b6af722 · verified · 1 Parent(s): f985c3e

Upload 381 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +22 -0
  2. assets/demo_1.gif +3 -0
  3. assets/demo_2.gif +3 -0
  4. assets/demo_3.gif +3 -0
  5. assets/demo_dynamic.gif +3 -0
  6. assets/diffusion/000000.png +3 -0
  7. assets/diffusion/000001.png +3 -0
  8. assets/diffusion/000002.png +3 -0
  9. assets/diffusion/000003.png +3 -0
  10. assets/diffusion/000004.png +3 -0
  11. assets/diffusion/000005.png +3 -0
  12. assets/diffusion/000006.png +3 -0
  13. assets/diffusion/000007.png +3 -0
  14. assets/diffusion/000008.png +3 -0
  15. assets/diffusion/000009.png +3 -0
  16. assets/diffusion/000010.png +3 -0
  17. assets/diffusion/000011.png +3 -0
  18. assets/diffusion/000012.png +3 -0
  19. assets/diffusion/000013.png +3 -0
  20. assets/diffusion/000014.png +3 -0
  21. assets/diffusion/000015.png +3 -0
  22. checkpoints/.DS_Store +0 -0
  23. checkpoints/README.md +4 -0
  24. cosmos-predict1.yaml +29 -0
  25. cosmos_predict1/.DS_Store +0 -0
  26. cosmos_predict1/__init__.py +14 -0
  27. cosmos_predict1/autoregressive/__init__.py +14 -0
  28. cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py +352 -0
  29. cosmos_predict1/autoregressive/configs/__init__.py +14 -0
  30. cosmos_predict1/autoregressive/configs/base/__init__.py +14 -0
  31. cosmos_predict1/autoregressive/configs/base/callbacks.py +33 -0
  32. cosmos_predict1/autoregressive/configs/base/dataloader.py +72 -0
  33. cosmos_predict1/autoregressive/configs/base/dataset.py +39 -0
  34. cosmos_predict1/autoregressive/configs/base/model.py +318 -0
  35. cosmos_predict1/autoregressive/configs/base/model_config.py +718 -0
  36. cosmos_predict1/autoregressive/configs/base/model_parallel.py +33 -0
  37. cosmos_predict1/autoregressive/configs/base/optim.py +86 -0
  38. cosmos_predict1/autoregressive/configs/base/tokenizer.py +139 -0
  39. cosmos_predict1/autoregressive/configs/config.py +111 -0
  40. cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py +0 -0
  41. cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py +163 -0
  42. cosmos_predict1/autoregressive/configs/inference/inference_config.py +102 -0
  43. cosmos_predict1/autoregressive/configs/registry.py +89 -0
  44. cosmos_predict1/autoregressive/datasets/dataset_utils.py +173 -0
  45. cosmos_predict1/autoregressive/datasets/video_dataset.py +190 -0
  46. cosmos_predict1/autoregressive/diffusion_decoder/__init__.py +14 -0
  47. cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py +61 -0
  48. cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py +61 -0
  49. cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py +85 -0
  50. cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py +118 -0
.gitattributes CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demo_1.gif filter=lfs diff=lfs merge=lfs -text
+ assets/demo_2.gif filter=lfs diff=lfs merge=lfs -text
+ assets/demo_3.gif filter=lfs diff=lfs merge=lfs -text
+ assets/demo_dynamic.gif filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000000.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000001.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000002.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000003.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000004.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000005.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000006.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000007.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000008.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000009.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000010.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000011.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000012.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000013.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000014.png filter=lfs diff=lfs merge=lfs -text
+ assets/diffusion/000015.png filter=lfs diff=lfs merge=lfs -text
+ cosmos_predict1/tokenizer/test_data/image.png filter=lfs diff=lfs merge=lfs -text
+ cosmos_predict1/tokenizer/test_data/video.mp4 filter=lfs diff=lfs merge=lfs -text
assets/demo_1.gif ADDED

Git LFS Details

  • SHA256: e6162366c56277d084b05a37c617e2994ba75285d421e203556dcff08128b32b
  • Pointer size: 133 Bytes
  • Size of remote file: 14.7 MB
assets/demo_2.gif ADDED

Git LFS Details

  • SHA256: e765e71d3016c6e314b6403f82313a1df42f68f6fb0f9416f197d82e0710f27e
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
assets/demo_3.gif ADDED

Git LFS Details

  • SHA256: 8c4cf4a4bf62daf03b25ac66c2c3693adbf7cd459e55d3481a65a9ff4a9d09d9
  • Pointer size: 133 Bytes
  • Size of remote file: 35.3 MB
assets/demo_dynamic.gif ADDED

Git LFS Details

  • SHA256: 174faba45ae701eaa432dd14de1297c0479b6c0b832adbc211cbb529fbec6c61
  • Pointer size: 133 Bytes
  • Size of remote file: 24.5 MB
assets/diffusion/000000.png ADDED

Git LFS Details

  • SHA256: b7e6eab7548c2ede900f8b504a5cef981e0cd0ec38af90dbea3f0db860e002c3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
assets/diffusion/000001.png ADDED

Git LFS Details

  • SHA256: abe310078829c9e1375ac30c7c270c84c8f68a09f3857bd35c7a5754f3326151
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
assets/diffusion/000002.png ADDED

Git LFS Details

  • SHA256: 7ad89b53e9fafed0d8eefd1cfc7cc4889c5d2f510ed32d5247c5adab4cb0c622
  • Pointer size: 131 Bytes
  • Size of remote file: 789 kB
assets/diffusion/000003.png ADDED

Git LFS Details

  • SHA256: 22f39915f1b277e70683befbc18ac5859c65c3d389e4dbb5127a539a411fec54
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
assets/diffusion/000004.png ADDED

Git LFS Details

  • SHA256: e2f957208849c0f86b89545734bb7b243868b574554cb6aeed248b04e7234ad4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
assets/diffusion/000005.png ADDED

Git LFS Details

  • SHA256: 267f6ae47d0e2aebda89fac5416bc0915855043131d0d8d8a4fc9506cabd4681
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
assets/diffusion/000006.png ADDED

Git LFS Details

  • SHA256: 4b6fd098366bcd54bd21a5707ae6d9f78d74c2eefcfbb6919569c0d1741d837f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
assets/diffusion/000007.png ADDED

Git LFS Details

  • SHA256: 334733b7428f9521e625a8b310770fbba3e4616ccbe0af625d07e2b065e6e9ad
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
assets/diffusion/000008.png ADDED

Git LFS Details

  • SHA256: 7eae1abb3343c1e11f4e42172eba85eeed0fb2a5f7701a42e5003cf84f1696cd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
assets/diffusion/000009.png ADDED

Git LFS Details

  • SHA256: 2a5c5711d41f56bb307ef6020d0dffec9ce2297bda9ef9ae465237d8347adb34
  • Pointer size: 131 Bytes
  • Size of remote file: 603 kB
assets/diffusion/000010.png ADDED

Git LFS Details

  • SHA256: e4d32f1d1c6d427e421d6f4478d4c2c697cb0406a18ecc3b8ebeeb2a0cbba7f5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
assets/diffusion/000011.png ADDED

Git LFS Details

  • SHA256: e352d7435d3b313fcc47efd9bd0dc6e0dd5d5e8af8c50e965c57987bee1c94ec
  • Pointer size: 131 Bytes
  • Size of remote file: 944 kB
assets/diffusion/000012.png ADDED

Git LFS Details

  • SHA256: b672d43521890b2852976a0c12828ad16b9288277efff6c41189dc0c04c9c6e1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
assets/diffusion/000013.png ADDED

Git LFS Details

  • SHA256: eab3a655213eede094889bab94313e1cef142b811429bee9e0f3420c2b013105
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
assets/diffusion/000014.png ADDED

Git LFS Details

  • SHA256: eb014db53082677aca35a3fc27daa1f306452c5cb7130a4ed6468cae144a0b63
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
assets/diffusion/000015.png ADDED

Git LFS Details

  • SHA256: a6ac0d4e7eb6d4dbc3ae997fafc28721b716db092aaa52ede11e4d87b3e9b20d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
checkpoints/.DS_Store ADDED
Binary file (6.15 kB).
 
checkpoints/README.md ADDED
@@ -0,0 +1,4 @@
+
+ ### Checkpoint directory
+
+ Model checkpoints will be downloaded to this directory.
cosmos-predict1.yaml ADDED
@@ -0,0 +1,29 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # conda env create --file cosmos-predict1.yaml
+ name: cosmos-predict1
+ channels:
+ - conda-forge
+ dependencies:
+ - python=3.10
+ - pip=25.0
+ - cmake
+ - ninja
+ - gcc=12.4.0
+ - gxx=12.4.0
+ - cuda=12.4
+ - cuda-nvcc=12.4
+ - cuda-toolkit=12.4
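As the comment at the top of the file indicates, the environment is created with `conda env create --file cosmos-predict1.yaml`; activating it afterwards (for example with `conda activate cosmos-predict1`) is the usual follow-up step, though that command is not part of the file itself.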
cosmos_predict1/.DS_Store ADDED
Binary file (8.2 kB).
 
cosmos_predict1/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
cosmos_predict1/autoregressive/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py ADDED
@@ -0,0 +1,352 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import glob
+ import math
+ import os
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torchvision
+ import torchvision.transforms.functional as torchvision_F
+ import wandb
+ from einops import rearrange
+ from megatron.core import parallel_state
+ from torch.distributed import get_process_group_ranks
+
+ from cosmos_predict1.autoregressive.utils.parallel import (
+     broadcast_data_batch_in_tp_cp_group,
+     gather_batch_from_cp_ranks,
+     get_batch_on_this_cp_rank,
+ )
+ from cosmos_predict1.callbacks.every_n import EveryN
+ from cosmos_predict1.utils import distributed, log, misc
+ from cosmos_predict1.utils.model import Model
+ from cosmos_predict1.utils.trainer import Trainer
+
+
+ def resize_image(image: torch.Tensor, resize_factor=0.5) -> torch.Tensor:
+     _, _, h, w = image.shape
+     new_h, new_w = int(resize_factor * h), int(resize_factor * w)
+     return torchvision_F.resize(image, (new_h, new_w))
+
+
+ class VideoSamplingTeacherForcing(EveryN):
+     def __init__(
+         self,
+         every_n: int,
+         step_size: int = 1,
+         video_latent_shape: list = [6, 24, 40],
+         num_frames_to_display: int = 4,
+         save_folder: Optional[str] = None,
+         num_file_to_log: int = 8,
+     ):
+         r"""
+         This callback enables us to perform teacher forcing inference on the training data.
+         By teacher forcing, we mean providing ground truth video tokens as inputs, and simply asking the model
+         to predict the next tokens. The predicted next tokens are then visualized. This does not perform
+         autoregressive sampling.
+         We also upload the downsampled video frames to wandb. Downsampling is needed for wandb to work fast.
+
+         Args:
+             every_n (int): Call this callback every_n steps
+             step_size (int): Number of steps taken for gradient accumulation. Global iteration number is
+                 iteration // self.step_size
+             video_latent_shape (list): Shape of the video latent
+             num_frames_to_display (int): Number of frames to subsample for displaying in wandb
+             save_folder (str): Name of the local folder to save the video
+             num_file_to_log (int): Number of files to upload to wandb
+         """
+         super().__init__(every_n, step_size)
+         self.save_folder = save_folder if save_folder else self.__class__.__name__
+         self.video_latent_shape = video_latent_shape
+         self.num_frames_to_display = num_frames_to_display
+         self.num_file_to_log = num_file_to_log
+         self.rank = distributed.get_rank()
+
+     def on_train_start(self, model: Model, iteration: int = 0) -> None:
+         config_job = self.config.job
+         self.local_dir = f"{config_job.path_local}/{self.save_folder}"
+         if self.rank == 0:
+             os.makedirs(self.local_dir, exist_ok=True)
+             log.info(f"Video Teacher-Forcing Callback: local_dir: {self.local_dir}")
+
+     @torch.inference_mode()
+     def every_n_impl(
+         self,
+         trainer: Trainer,
+         model: Model,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int,
+     ) -> None:
+         # Tokenize the data
+
+         broadcast_data_batch_in_tp_cp_group(data_batch)
+
+         input_vid = data_batch[model.tokenizer.tokenizer_config.video_tokenizer.data_key]
+
+         dataset_name = data_batch.get("dataset_name", None)
+         if dataset_name is not None and dataset_name.startswith("image"):
+             # we disable the callback if the input video is an image batch
+             log.info(f"dataset_name is {dataset_name}, skip this callback")
+             return
+
+         # get the caption
+         captions = data_batch.get("caption", None)
+
+         # get the context embedding and mask
+         context = data_batch.get("context", None)
+         context_mask = data_batch.get("context_mask", None)
+         if context is not None:
+             context = misc.to(context, "cuda").detach().clone()
+         if context_mask is not None:
+             context_mask = misc.to(context_mask, "cuda").detach().clone()
+         # get the action
+         action = data_batch.get("action", None)
+         if action is not None:
+             action = misc.to(action, "cuda").detach().clone()
+
+         # Input tokens
+         tokens, _ = model.tokenizer.tokenize(data_batch)
+         tokens = misc.to(tokens, "cuda").detach().clone()
+         skip_save_file = False
+         if parallel_state.get_context_parallel_world_size() > 1:
+             cp_group = parallel_state.get_context_parallel_group()
+             if self.rank != min(get_process_group_ranks(cp_group)):
+                 skip_save_file = True
+             tokens = get_batch_on_this_cp_rank(tokens)
+         if parallel_state.get_tensor_model_parallel_world_size() > 1:
+             # Turn on TP
+             tp_group = parallel_state.get_tensor_model_parallel_group()
+             if self.rank != min(get_process_group_ranks(tp_group)):
+                 skip_save_file = True
+         tokens_encoded_in_train = output_batch["encode_tokens"].detach()
+         percent_token_diff = (tokens != tokens_encoded_in_train).float().mean()
+         percent_token_diff = distributed.dist_reduce_tensor(percent_token_diff)
+
+         input_tokens = tokens
+
+         num_tokens_to_generate = np.prod(self.video_latent_shape)
+
+         # Do a forward pass
+         logits = model.model.forward(
+             tokens,
+             input_pos=None,
+             context=context,
+             context_mask=context_mask,
+             action=action,
+         )
+         if parallel_state.get_context_parallel_world_size() > 1:
+             logits = gather_batch_from_cp_ranks(logits)
+             input_tokens = gather_batch_from_cp_ranks(input_tokens)
+
+         # Start position for video tokens in the vocabulary
+         video_token_start = self.config.model.tokenizer_config.video_tokenizer.tokenizer_offset
+         video_vocab_size = self.config.model.tokenizer_config.video_tokenizer.vocab_size
+
+         # Clipping logits only to video tokens. We remove the text vocab predictions.
+         # This will ensure that the video tokens only correspond to the video part of the vocabulary.
+         logits = logits[:, :, video_token_start : video_token_start + video_vocab_size]
+
+         # Sample with argmax token. This should be good for teacher forcing experiment.
+         logits = logits.contiguous()
+         generations = torch.argmax(logits, dim=-1)
+
+         # For each video in the batch, subsample frames for display
+         batch_size = input_tokens.shape[0]
+         out_frames = []
+         out_videos_gen = []
+         out_videos_rec = []
+         out_videos_gt = []
+         # log the accuracy of teacher-forcing
+         acc = []
+         loss_list = []
+
+         for sample_num in range(batch_size):
+             # Subsample the generations to the video part.
+             # This corresponds to the part from begin of video to end of video.
+             bov_token = model.tokenizer.video_special_tokens["<|begin_of_video|>"]
+             bov_index = input_tokens[sample_num] == bov_token
+             use_special_token = sum(bov_index) != 0
+             if use_special_token:
+                 bov_index = bov_index.nonzero().item()
+                 # generations: <bov> real_token1 real_token2, ... real_token7680; total 7680
+                 # gen_video_tokens: real_token1 real_token2, ..., real_token7680; total 7680
+                 # for vis: real_token1 real_token2, ..., real_token7680; total 7680
+                 # for accuracy: real_token1 real_token2, ..., real_token7680; total 7680
+                 gen_video_tokens = generations[sample_num][bov_index : bov_index + num_tokens_to_generate]
+                 gen_video_tokens_vis = gen_video_tokens
+                 gen_video_tokens_acc = gen_video_tokens
+                 logits_loss = logits[sample_num][bov_index : bov_index + num_tokens_to_generate]
+             else:
+                 # generations: real_token1 real_token2, ... real_token7680
+                 # gen_video_tokens: real_token2 real_token3, ..., real_token7680; total 7679
+                 # We need different tokens for vis and accuracy compute
+                 # for acc: real_token2 real_token3, ..., real_token7680; total 7679
+                 # for vis: pad_token (real_token2, ..., real_token7680); total 1 + 7679
+                 gen_video_tokens = generations[sample_num][
+                     : num_tokens_to_generate - 1
+                 ]  # remove the last token since there is no gt
+                 # Since the first token is not predicted, we need to add the gt first token to make sure the shape is correct
+                 gen_video_tokens_vis = torch.cat([input_tokens[sample_num][0:1], gen_video_tokens])
+                 gen_video_tokens_acc = gen_video_tokens
+                 logits_loss = logits[sample_num][: num_tokens_to_generate - 1]
+
+             # Rearrange the video to a spatial tensor
+             gen_video_tokens_vis_BTHW = rearrange(
+                 gen_video_tokens_vis.unsqueeze(0),
+                 "B (T H W) -> B T H W",
+                 T=self.video_latent_shape[0],
+                 H=self.video_latent_shape[1],
+                 W=self.video_latent_shape[2],
+             )
+
+             # for real videos, we need to skip the bov and eov tokens for decoding
+             if use_special_token:
+                 # input_tokens: <bov> real_token1 real_token2 ... <eov> <eov> ...
+                 # real_video_tokens: real_token1 real_token2 ... real_token7680; total 7680
+                 # for vis: real_token1 real_token2 ... real_token7680; total 7680
+                 # for accuracy: real_token1 real_token2 ... real_token7680; total 7680; we include real_token1 since the output prediction also includes it, see gen_video_tokens_acc above
+                 real_video_tokens = (
+                     input_tokens[sample_num][bov_index + 1 : bov_index + num_tokens_to_generate + 1] - video_token_start
+                 )
+                 real_video_tokens_vis = real_video_tokens
+                 real_video_tokens_acc = real_video_tokens
+             else:
+                 # input_tokens: real_token1 real_token2,... real_token7680; total 7680
+                 # real_video_tokens: real_token1 real_token2,... real_token7680; total 7680
+                 # for acc: gt start from real_token2, real_token3; total 7679, remove the first token since it is not predicted
+                 # for vis: gt start from real_token1, real_token2; total 7680
+                 real_video_tokens = input_tokens[sample_num][:num_tokens_to_generate] - video_token_start
+                 real_video_tokens_vis = real_video_tokens
+                 real_video_tokens_acc = real_video_tokens[1:].flatten()
+
+             real_video_tokens_vis_BTHW = rearrange(
+                 real_video_tokens_vis.unsqueeze(0),
+                 "B (T H W) -> B T H W",
+                 T=self.video_latent_shape[0],
+                 H=self.video_latent_shape[1],
+                 W=self.video_latent_shape[2],
+             )
+             # Calculate accuracy
+             correct_predictions = (gen_video_tokens_acc == real_video_tokens_acc).float()
+             labels = real_video_tokens_acc.clone()
+
+             if model.config.ignore_first_num_tokens > 0:
+                 labels[: model.config.ignore_first_num_tokens] = model.tokenizer.ignore_index
+                 select_index = labels != model.tokenizer.ignore_index
+                 correct_predictions = correct_predictions[select_index]
+
+             loss = torch.nn.functional.cross_entropy(
+                 logits_loss, labels, ignore_index=model.tokenizer.ignore_index, reduction="none"
+             )
+             acc.append(correct_predictions.mean() * 100.0)
+             loss_list.append(loss.mean())
+
+             # Decode the predicted latents
+             if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
+                 vid_decoded = model.tokenizer.video_tokenizer.decode(gen_video_tokens_vis_BTHW.cuda())
+             else:
+                 vid_decoded = model.tokenizer.video_tokenizer.decode_with_overlap(
+                     gen_video_tokens_vis_BTHW.cuda(),
+                     temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
+                 )
+             # normalize decoded images from [-1, 1] to [0, 1], and clip value
+             vid_decoded = (vid_decoded * 0.5 + 0.5).clamp_(0, 1)
+             vid_decoded = vid_decoded[0]
+
+             # Decode the GT latents
+             if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
+                 vid_rec = model.tokenizer.video_tokenizer.decode(real_video_tokens_vis_BTHW.cuda())
+             else:
+                 vid_rec = model.tokenizer.video_tokenizer.decode_with_overlap(
+                     real_video_tokens_vis_BTHW.cuda(),
+                     temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
+                 )
+             # normalize decoded image from [-1, 1] to [0, 1], and clip value
+             vid_rec = (vid_rec * 0.5 + 0.5).clamp_(0, 1)
+             vid_rec = vid_rec[0]
+
+             vid_input = input_vid[sample_num]  # [-1, 1], input_vid shape: [B, C, L, H, W]
+             vid_input = (vid_input * 0.5 + 0.5).clamp_(0, 1).cuda()  # Convert to [0, 1], [C, L, H, W]
+
+             # Subsample real and generated video frames
+             input_video_frames = vid_input.transpose(0, 1)  # [L, C, H, W]
+             rec_video_frames = vid_rec.transpose(0, 1)
+             gen_video_frames = vid_decoded.transpose(0, 1)
+             out_videos_gen.append(gen_video_frames)
+             out_videos_rec.append(rec_video_frames)
+             out_videos_gt.append(input_video_frames)
+
+             stride = math.ceil(rec_video_frames.shape[0] / self.num_frames_to_display)
+
+             input_video_frames_subsampled = resize_image(input_video_frames[0::stride], resize_factor=0.5)
+             input_video_frames_subsampled = torchvision.utils.make_grid(
+                 input_video_frames_subsampled, nrow=input_video_frames_subsampled.shape[0]
+             )
+
+             gt_video_frames_subsampled = resize_image(rec_video_frames[0::stride], resize_factor=0.5)
+             gt_video_frames_subsampled = torchvision.utils.make_grid(
+                 gt_video_frames_subsampled, nrow=gt_video_frames_subsampled.shape[0]
+             )
+             gen_video_frames_subsampled = resize_image(gen_video_frames[0::stride], resize_factor=0.5)
+             gen_video_frames_subsampled = torchvision.utils.make_grid(
+                 gen_video_frames_subsampled, nrow=gen_video_frames_subsampled.shape[0]
+             )
+
+             out_frames.append(input_video_frames_subsampled)
+             out_frames.append(gt_video_frames_subsampled)
+             out_frames.append(gen_video_frames_subsampled)
+
+         scaled_num_rank_to_log = (
+             self.num_file_to_log
+             * parallel_state.get_context_parallel_world_size()
+             * parallel_state.get_tensor_model_parallel_world_size()
+         )
+         if self.rank < scaled_num_rank_to_log and not skip_save_file:
+             local_path = f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_{self.rank:04d}.jpg"
+             out_image_grid = torchvision.utils.make_grid(out_frames, nrow=1, padding=0, normalize=False)
+             os.makedirs(os.path.dirname(local_path), exist_ok=True)
+             torchvision.utils.save_image(out_image_grid, local_path)
+
+         # Log to wandb
+         avg_acc = distributed.dist_reduce_tensor(torch.stack(acc).mean()).item()
+         avg_loss = distributed.dist_reduce_tensor(torch.stack(loss_list).mean()).item()
+         log_info = ""
+         if "acc" in output_batch:
+             log_info = f"train acc: {(output_batch['acc'].mean().item()):.6f}%"
+         if percent_token_diff is not None:
+             log_info += f"; percent_token_diff_train_val: {percent_token_diff.item() * 100:.6f}%"
+         log.info(
+             f"Eval iteration {iteration} teacher-forcing accuracy: {avg_acc:.6f}%, loss: {avg_loss:.4f}; {log_info}"
+         )
+         if self.rank == 0 and wandb.run:
+             local_files = glob.glob(f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_*.jpg")
+             local_files = sorted(local_files)[: self.num_file_to_log]
+             if captions is None:
+                 captions = ["vid_frames_teacher_forcing"] * len(local_files)
+             for local_path, caption in zip(local_files, captions):
+                 wandb.log(
+                     {"frames": [wandb.Image(local_path, caption=caption)]},
+                     step=iteration,
+                 )
+
+             wandb.log({"eval/teacher_forcing_acc": avg_acc}, step=iteration)
+             wandb.log({"eval/teacher_forcing_loss": avg_loss}, step=iteration)
+             if percent_token_diff is not None:
+                 wandb.log({"eval/percent_token_diff_train_val": percent_token_diff.item() * 100}, step=iteration)
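The core of the callback above is a teacher-forcing check: logits are restricted to the video slice of the vocabulary, argmax-decoded, and compared token-by-token against the ground-truth tokens after shifting them by the tokenizer offset. The sketch below reproduces just that accuracy computation on random tensors; the shapes and the vocabulary offset are illustrative assumptions, not values from the repository configs.

```python
import torch

# Illustrative sizes only (assumptions, not repository values).
batch, seq_len, full_vocab = 2, 16, 64
video_token_start, video_vocab_size = 8, 32  # assumed offset of video tokens in the joint vocabulary

logits = torch.randn(batch, seq_len, full_vocab)
gt_tokens = torch.randint(
    video_token_start, video_token_start + video_vocab_size, (batch, seq_len)
)

# Keep only the video part of the vocabulary before argmax, as the callback does.
video_logits = logits[:, :, video_token_start : video_token_start + video_vocab_size]
pred_tokens = torch.argmax(video_logits, dim=-1)  # indices within the video vocabulary

# Ground truth is shifted back into the video vocabulary before comparison.
accuracy = (pred_tokens == (gt_tokens - video_token_start)).float().mean() * 100.0
print(f"teacher-forcing accuracy: {accuracy.item():.2f}%")
```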
cosmos_predict1/autoregressive/configs/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
cosmos_predict1/autoregressive/configs/base/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
cosmos_predict1/autoregressive/configs/base/callbacks.py ADDED
@@ -0,0 +1,33 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing
+ from cosmos_predict1.callbacks.grad_clip import GradClip
+ from cosmos_predict1.utils.callback import ProgressBarCallback
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
+
+ BASIC_CALLBACKS = dict(
+     progress_bar=L(ProgressBarCallback)(),
+     grad_clip=L(GradClip)(clip_norm=1.0, fsdp_enabled="${model.model_config.fsdp_enabled}", model_key="model"),
+ )
+
+ VIDEO_TEACHER_FORCING_CALLBACK = dict(
+     vid_sampling_tf=L(VideoSamplingTeacherForcing)(
+         every_n=500,
+         video_latent_shape="${model.model_config.video_latent_shape}",
+         num_frames_to_display=4,
+         save_folder="video_sampling_teacher_forcing",
+     )
+ )
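Both objects above are plain dictionaries of lazily constructed callbacks, so a downstream experiment config can presumably merge them into a single mapping; the merge below is a sketch of that assumption rather than code taken from the repository.

```python
# Hypothetical merge of the two callback dictionaries defined above.
callbacks = {**BASIC_CALLBACKS, **VIDEO_TEACHER_FORCING_CALLBACK}
# -> progress_bar, grad_clip, and vid_sampling_tf entries in one mapping.
```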
cosmos_predict1/autoregressive/configs/base/dataloader.py ADDED
@@ -0,0 +1,72 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from megatron.core import parallel_state
+ from torch.utils.data import DataLoader, DistributedSampler
+
+ from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
+ from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset
+ from cosmos_predict1.utils import log
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
+
+ DATALOADER_OPTIONS = {}
+
+
+ def get_sampler(dataset):
+     return DistributedSampler(
+         dataset,
+         num_replicas=parallel_state.get_data_parallel_world_size(),
+         rank=parallel_state.get_data_parallel_rank(),
+         shuffle=True,
+         seed=0,
+     )
+
+
+ def dataloader_register(key):
+     log.info(f"registering dataloader {key}...")
+
+     def decorator(func):
+         DATALOADER_OPTIONS[key] = func
+         return func
+
+     return decorator
+
+
+ @dataloader_register("tealrobot_video")
+ def get_tealrobot_video(
+     batch_size: int = 1,
+     dataset_dir: str = "datasets/cosmos_nemo_assets/videos/",
+     sequence_interval: int = 1,
+     num_frames: int = 33,
+     video_size: list[int, int] = [640, 848],
+     start_frame_interval: int = 1,
+ ):
+     dataset = L(VideoDataset)(
+         config=VideoDatasetConfig(
+             dataset_dir=dataset_dir,
+             sequence_interval=sequence_interval,
+             num_frames=num_frames,
+             video_size=video_size,
+             start_frame_interval=start_frame_interval,
+         )
+     )
+     return L(DataLoader)(
+         dataset=dataset,
+         sampler=L(get_sampler)(dataset=dataset),
+         batch_size=batch_size,
+         drop_last=True,
+         pin_memory=True,
+         num_workers=8,
+     )
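The `dataloader_register` decorator stores each factory in `DATALOADER_OPTIONS` under its key, so a config can look a dataloader up by name and call the factory with overrides. The call below is a usage sketch based on that registry pattern; how the repository actually consumes the registry is not shown in this file.

```python
# Hypothetical lookup-and-call against the registry defined above.
dataloader_spec = DATALOADER_OPTIONS["tealrobot_video"](batch_size=2)
# dataloader_spec is a LazyCall-wrapped DataLoader description, not an instantiated loader.
```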
cosmos_predict1/autoregressive/configs/base/dataset.py ADDED
@@ -0,0 +1,39 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Dataset config class."""
+
+ import attrs
+
+ from cosmos_predict1.utils.config import make_freezable
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class VideoDatasetConfig:
+     """
+     Args:
+         dataset_dir (str): Base path to the dataset directory
+         sequence_interval (int): Interval between sampled frames in a sequence
+         num_frames (int): Number of frames to load per sequence
+         video_size (list): Target size [H,W] for video frames
+         start_frame_interval (int): Interval between starting frames of sequences
+     """
+
+     dataset_dir: str = "datasets/cosmos_nemo_assets/videos/"
+     sequence_interval: int = 1
+     num_frames: int = 33
+     video_size: list[int, int] = [640, 848]
+     start_frame_interval: int = 1
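A minimal construction sketch for the config above; the values shown are simply the defaults listed in the class.

```python
# Construct the dataset config with its documented fields (values are the defaults above).
cfg = VideoDatasetConfig(
    dataset_dir="datasets/cosmos_nemo_assets/videos/",
    sequence_interval=1,
    num_frames=33,
    video_size=[640, 848],
    start_frame_interval=1,
)
```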
cosmos_predict1/autoregressive/configs/base/model.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig
21
+ from cosmos_predict1.utils import config
22
+
23
+ _ACTION_DIM = 8
24
+ from cosmos_predict1.utils.lazy_config import LazyDict
25
+
26
+
27
+ @attrs.define
28
+ class ModelConfig:
29
+ """
30
+ A class to hold model configuration arguments.
31
+
32
+ Args:
33
+ dim (int): The dimensionality of the input and output of each transformer block.
34
+ n_layers (int): Number of layers in the transformer.
35
+ n_heads (int): Number of attention heads.
36
+ n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
37
+ `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
38
+ head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
39
+ vocab_size (int): Vocabulary size.
40
+ ffn_hidden_size (int): Hidden size for feedforward network.
41
+ norm_eps (float): Epsilon value for normalization.
42
+ rope_theta (float): Theta value for rotary positional embeddings.
43
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
44
+ max_batch_size (int): Maximum batch size for inference.
45
+ max_seq_len (int): Maximum sequence length for input text.
46
+ fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True.
47
+ causal_mask (bool): Whether to use causal mask. Defaults to True.
48
+ norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
49
+ precision (str): Data type for the model.
50
+ use_qk_normalization (bool): Whether to enable QK normalization.
51
+ tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
52
+ ckpt_dir (str): Checkpoint directory.
53
+ ckpt_path (str): Checkpoint path.
54
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
55
+ yarn_scale (Optional[float]): Scale factor for YaRN.
56
+ yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
57
+ yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
58
+ original_seq_len (Optional[int]): Original sequence length.
59
+ vision_encoder (Optional[str]): Vision encoder name.
60
+ mm_projector (Optional[str]): Multi-modal projector name.
61
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
62
+ rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
63
+ pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
64
+ original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
65
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
66
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3.
67
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
68
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
69
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
70
+ num_video_frames (Optional[int]): Number of video frames.
71
+ video_height (Optional[int]): Raw video pixel height dimension.
72
+ video_width (Optional[int]): Raw video pixel width dimension.
73
+ video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W).
74
+ """
75
+
76
+ dim: int = attrs.field(default=4096)
77
+ n_layers: int = attrs.field(default=32)
78
+ n_heads: int = attrs.field(default=32)
79
+ n_kv_heads: Optional[int] = attrs.field(default=8)
80
+ head_dim: Optional[int] = attrs.field(default=None)
81
+ vocab_size: int = attrs.field(default=128256)
82
+ ffn_hidden_size: int = attrs.field(default=14336)
83
+ norm_eps: float = attrs.field(default=1e-5)
84
+ rope_theta: float = attrs.field(default=500000)
85
+ apply_abs_pos_emb: bool = attrs.field(default=False)
86
+ max_batch_size: int = attrs.field(default=1)
87
+ max_seq_len: int = attrs.field(default=8192)
88
+ fuse_qkv: bool = attrs.field(default=False)
89
+ causal_mask: bool = attrs.field(default=True)
90
+ norm_type: str = attrs.field(default="rmsnorm")
91
+ precision: str = attrs.field(default="bfloat16")
92
+ use_qk_normalization: bool = False
93
+ tokenizer: Optional[TokenizerConfig] = None
94
+ tensor_model_parallel_size: int = attrs.field(default=1)
95
+ ckpt_dir: Optional[str] = attrs.field(default=None)
96
+ ckpt_path: Optional[str] = attrs.field(
97
+ default=None
98
+ ) # If not None, load the model from this path instead of ckpt_dir
99
+ apply_yarn: Optional[bool] = attrs.field(default=False)
100
+ yarn_scale: Optional[float] = attrs.field(default=None)
101
+ yarn_beta_fast: Optional[int] = attrs.field(default=None)
102
+ yarn_beta_slow: Optional[int] = attrs.field(default=None)
103
+ original_seq_len: Optional[int] = attrs.field(default=None)
104
+ vision_encoder: Optional[str] = attrs.field(default=None)
105
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
106
+ mm_projector: Optional[str] = attrs.field(default=None)
107
+ rope_dim: Optional[str] = attrs.field(default="1D")
108
+ pytorch_rope_version: Optional[str] = attrs.field(default="v2")
109
+ original_latent_shape: Optional[list] = None
110
+ pad_to_multiple_of: Optional[int] = None
111
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
112
+ insert_cross_attn: bool = False
113
+ insert_cross_attn_every_k_layers: int = 1
114
+ context_dim: Optional[int] = attrs.field(default=1024)
115
+ # For video training
116
+ num_video_frames: Optional[int] = None
117
+ # Raw video pixel dimension
118
+ video_height: Optional[int] = None
119
+ video_width: Optional[int] = None
120
+ # Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact
121
+ video_latent_shape: Optional[list] = None
122
+
123
+ def __getitem__(self, item):
124
+ return getattr(self, item)
125
+
126
+
127
+ @attrs.define
128
+ class TrainingModelConfig:
129
+ """
130
+ A class to hold model configuration arguments.
131
+
132
+ Args:
133
+ dim (int): The dimensionality of the input and output of each transformer block.
134
+ n_layers (int): Number of layers in the transformer.
135
+ n_heads (int): Number of attention heads.
136
+ n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
137
+ `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
138
+ head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
139
+ vocab_size (int): Vocabulary size.
140
+ multiple_of (int): Ensures the hidden layer size is a multiple of this value for SwiGLU activation.
141
+ ffn_dim_multiplier (Optional[float]): Multiplier for feedforward network dimension.
142
+ ffn_hidden_size (Optional[int]): Hidden size for feedforward network. If None, use ffn_dim_multiplier to compute it.
143
+ norm_eps (float): Epsilon value for normalization.
144
+ rope_theta (float): Theta value for rotary positional embeddings.
145
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
146
+ max_batch_size (int): Maximum batch size for inference (determines KV cache size).
147
+ max_seq_len (int): Maximum sequence length for input text (determines KV cache size).
148
+ fuse_qkv (bool): Whether to fuse QKV in attention. Flag for the pytorch backend.
149
+ causal_mask (bool): Whether to use causal mask. Defaults to True.
150
+ flash_attn (bool): Whether to use Flash attention.
151
+ norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
152
+ backend (str): Backend for the model.
153
+ precision (str): Data type for the model.
154
+ ema (config.EMAConfig): Configuration for exponential moving average.
155
+ embedding_dropout(float): Dropout rate for the embedding layer.
156
+ attention_dropout(float): Dropout rate for attention.
157
+ hidden_dropout(float): Dropout after the attention and feed-forward layers (following TransformerEngine's
158
+ implementation in its TransformerLayer class).
159
+ use_qk_normalization (bool): Whether to enable QK normalization.
160
+ inference (bool): Whether the model is used for inference.
161
+ act_ckpt_enabled (bool): Whether to enable activation checkpointing.
162
+ fsdp_enabled (bool): Whether to enable FSDP.
163
+ fsdp (LazyDict): Configuration for FSDP.
164
+ ckpt_dir (str): Checkpoint directory.
165
+ ckpt_path (str): Checkpoint path.
166
+ cache_dir (str): Cache directory.
167
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
168
+ yarn_scale (Optional[float]): Scale factor for YaRN.
169
+ yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
170
+ yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
171
+ original_seq_len (Optional[int]): Original sequence length.
172
+ depth_init (bool): If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the
173
+ total number of transformer blocks. Defaults to `True` (following the TorchTitan implementation of Llama3).
174
+ context_parallel_size (int): Context parallel size. Defaults to 1.
175
+ tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
176
+ sequence_parallel (bool): Whether to use sequence parallelism. Defaults to False.
177
+ set_parallel_mode (bool): It is a boolean flag used by TransformerEngine to handle Tensor Parallelism.
178
+ Essentially, it is equivalent to `tensor_model_parallel_size > 1`. Defaults to `False`.
179
+ attention_tp (bool): Whether to use tensor parallelism for attention layers.
180
+ mm_projector (Optional[str]): Multimodal projector used for vision-language modeling. Defaults to None.
181
+ Choices: "identity", "linear", "mlp", "mlp_downsample".
182
+ video_latent_shape (Optional[list]): Shape of the video latent tensor. [T, H, W]
183
+ image_latent_shape (Optional[list]): Shape of the image latent tensor. [H, W]
184
+ num_video_frames (Optional[int]): Number of video frames.
185
+ rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
186
+ pytorch_rope_version (Optional[str]): Version of the RoPE for the `pytorch` backend. "v1" is the Llama implementation, and "v2" is HuggingFace/TransformerEngine implementation.
187
+ original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
188
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
189
+ peft_last_n_layers (Optional[int]): Number of last few layers to fine-tune in Parameter Efficient Fine-Tuning (PEFT). When this and peft_every_n_layers are both 0, it means all layers are fine-tuned (FFT).
190
+ peft_every_n_layers (Optional[int]): In Parameter Efficient Fine-Tuning (PEFT), every n layers are unfrozen and can be trained (in flamingo style). When this and peft_last_n_layers are both 0,
191
+ it means all layers are fine-tuned (FFT). For example, for a 40 layer model, n=8 means training layers 7, 15, 23, 31, 39, which includes the final layer.
192
+ It is advised to pick n such that the final layer is included.
193
+ freeze_vision_encoder (bool): Whether to freeze the vision encoder in vision-language model training. Defaults to False.
194
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
195
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
196
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
197
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
198
+ finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
199
+ finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
200
+ use_action_condition (bool): Whether to use the robot action condition.
201
+ action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
202
+ action_dim (Optional[int]): The dimensionality of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
203
+ action_embedding_dim (Optional[int]): The dimensionality of the robot action embedding.
204
+ group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
205
+ sync_1d_parameters (bool): Whether to synchronize layernorm parameters (1D) across tensor parallel ranks (default True).
206
+ Note: this is to ensure all TP-ranks have the same layernorm parameters.
207
+ z_loss_coeff (float): The coefficient for the z-loss.
208
+ insert_medusa_head (bool): Whether to insert the Medusa head.
209
+ ft_medusa_option (str): Options on which layers to finetune, choices like:
210
+ "fft": fully fine-tune both medusa heads and all LLM backbone;
211
+ "head": fine-tune medusa heads;
212
+ "head_out": fine-tune medusa heads, and the output layer;
213
+ "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
214
+ medusa_num_heads (int): Number of heads in the Medusa head.
215
+ medusa_num_layers (int): Number of layers in the Medusa head.
216
+ medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
217
+ zero_init_cross_attn_proj (bool): Whether to initialize the cross-attn proj layer with zeros (default False).
218
+ concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
219
+ """
220
+
221
+ dim: int = attrs.field(default=4096)
222
+ n_layers: int = attrs.field(default=32)
223
+ n_heads: int = attrs.field(default=32)
224
+ n_kv_heads: Optional[int] = attrs.field(default=8)
225
+ head_dim: Optional[int] = attrs.field(default=None)
226
+ vocab_size: int = attrs.field(default=128256)
227
+ multiple_of: int = attrs.field(default=1024) # make SwiGLU hidden layer size multiple of large power of 2
228
+ ffn_dim_multiplier: Optional[float] = attrs.field(default=1.3)
229
+ ffn_hidden_size: Optional[int] = attrs.field(default=None)
230
+ norm_eps: float = attrs.field(default=1e-5)
231
+ rope_theta: float = attrs.field(default=500000)
232
+ apply_abs_pos_emb: bool = attrs.field(default=False)
233
+ max_batch_size: int = attrs.field(default=1)
234
+ max_seq_len: int = attrs.field(default=8192)
235
+ fuse_qkv: bool = attrs.field(default=False)
236
+ causal_mask: bool = attrs.field(default=True)
237
+ flash_attn: bool = attrs.field(default=True)
238
+ norm_type: str = attrs.field(default="rmsnorm")
239
+ backend: str = attrs.field(default="pytorch")
240
+ precision: str = attrs.field(default="bfloat16")
241
+ ema: config.EMAConfig = config.EMAConfig(enabled=False)
242
+ embedding_dropout: float = 0.0
243
+ attention_dropout: float = 0.0
244
+ hidden_dropout: float = 0.0
245
+ use_qk_normalization: bool = False
246
+ tokenizer: Optional[TokenizerConfig] = None
247
+ inference: bool = False
248
+ act_ckpt_enabled: bool = False
249
+ fsdp_enabled: bool = False
250
+ context_parallel_size: int = attrs.field(default=1)
251
+ tensor_model_parallel_size: int = attrs.field(default=1)
252
+ sequence_parallel: bool = attrs.field(default=False)
253
+ set_parallel_mode: bool = attrs.field(default=False)
254
+ fsdp: LazyDict = LazyDict(
255
+ dict(
256
+ policy="auto", # choices: ["size", "auto"]
257
+ min_num_params=1024, # Used as policy == "size"
258
+ sharding_strategy="hybrid", # Choices: ["full", "hybrid"]. "full" means sharding_group_size = world_size
259
+ sharding_group_size=8, # If None, defaults to min(world_size, 8). Recommends 8 for training on 8-GPU nodes.
260
+ )
261
+ )
262
+ ckpt_dir: Optional[str] = attrs.field(default="")
263
+ ckpt_path: Optional[str] = attrs.field(
264
+ default=None
265
+ ) # If not None, load the model from this path instead of ckpt_dir
266
+ cache_dir: Optional[str] = attrs.field(default="/project/cosmos/ar/cache")
267
+ apply_yarn: Optional[bool] = attrs.field(default=False)
268
+ yarn_scale: Optional[float] = attrs.field(default=None)
269
+ yarn_beta_fast: Optional[int] = attrs.field(default=None)
270
+ yarn_beta_slow: Optional[int] = attrs.field(default=None)
271
+ original_seq_len: Optional[int] = attrs.field(default=None)
272
+ depth_init: bool = attrs.field(default=True)
273
+ ignore_first_num_tokens: int = 0
274
+ z_loss_coeff: float = 1e-4
275
+ attention_tp: bool = False
276
+ vision_encoder: Optional[str] = attrs.field(default=None)
277
+ mm_projector: Optional[str] = attrs.field(default=None)
278
+ rope_dim: Optional[str] = attrs.field(default="1D")
279
+ pytorch_rope_version: Optional[str] = attrs.field(default="v2")
280
+ original_latent_shape: Optional[list] = None
281
+ pad_to_multiple_of: Optional[int] = None
282
+ peft_last_n_layers: Optional[int] = attrs.field(default=0)
283
+ peft_every_n_layers: Optional[int] = attrs.field(default=0)
284
+ freeze_vision_encoder: bool = False
285
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
286
+ insert_cross_attn: bool = False
287
+ insert_cross_attn_every_k_layers: int = 1
288
+ context_dim: Optional[int] = attrs.field(default=1024)
289
+ finetune_layers_with_cross_attn: bool = False
290
+ finetune_layers_without_cross_attn: bool = False
291
+ use_action_condition: bool = False
292
+ action_embedding_mode: Optional[str] = attrs.field(default="mlp")
293
+ action_dim: Optional[int] = attrs.field(default=_ACTION_DIM)
294
+ action_embedding_dim: Optional[int] = attrs.field(default=1024)
295
+ group_causal_mask_mode: Optional[str] = attrs.field(default=None)
296
+ sync_1d_parameters: bool = True
297
+ # hyper-parameters for the medusa head configs
298
+ insert_medusa_head: bool = False
299
+ ft_medusa_option: str = "fft"
300
+ medusa_num_heads: int = 7
301
+ medusa_num_layers: int = 1
302
+ medusa_concat_heads: bool = True
303
+ # For video training
304
+ num_video_frames: Optional[int] = None
305
+ # Raw video pixel dimension
306
+ video_height: Optional[int] = None
307
+ video_width: Optional[int] = None
308
+ # Video tokenizer output dimensions (T, H, W), computed as num_video_frames / temporal_compression_factor, video_height / spatial_compression_factor, video_width / spatial_compression_factor
309
+ video_latent_shape: Optional[list] = None
310
+ # For image training
311
+ image_latent_shape: Optional[list] = None
312
+ # For robot training (action)
313
+ zero_init_cross_attn_proj: bool = False
314
+ # For robot training (action)
315
+ concat_action_to_context: bool = False
316
+
317
+ def __getitem__(self, item):
318
+ return getattr(self, item)
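The __getitem__ above makes config fields readable with dict-style access. A minimal sketch, assuming these fields belong to TrainingModelConfig (imported by model_config.py below) and that the fields defined earlier in the class also carry defaults:

from cosmos_predict1.autoregressive.configs.base.model import TrainingModelConfig

cfg = TrainingModelConfig()  # assumes every field has a default, as the ones shown here do
assert cfg["max_seq_len"] == cfg.max_seq_len == 8192  # dict-style access via __getitem__
infer_cfg = TrainingModelConfig(inference=True, max_batch_size=4)  # override selected fields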
cosmos_predict1/autoregressive/configs/base/model_config.py ADDED
@@ -0,0 +1,718 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ from typing import Callable, List, Optional
18
+
19
+ import torch
20
+ from megatron.core import ModelParallelConfig
21
+
22
+ from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TrainingModelConfig
23
+ from cosmos_predict1.autoregressive.configs.base.tokenizer import (
24
+ TextTokenizerConfig,
25
+ TokenizerConfig,
26
+ VideoTokenizerConfig,
27
+ create_discrete_video_fsq_tokenizer_state_dict_config,
28
+ )
29
+ from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer
30
+ from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer
31
+ from cosmos_predict1.autoregressive.training.model import AutoRegressiveTrainingModel
32
+ from cosmos_predict1.utils import log
33
+ from cosmos_predict1.utils.config import EMAConfig
34
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
35
+
36
+ # Common architecture specifications
37
+ BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
38
+ COSMOS_ARCHITECTURES = {
39
+ "1b": {
40
+ "n_layers": 16,
41
+ "dim": 2048,
42
+ "n_heads": 32,
43
+ },
44
+ "4b": {
45
+ "n_layers": 16,
46
+ "dim": 4096,
47
+ "n_heads": 32,
48
+ },
49
+ "12b": {
50
+ "n_layers": 40,
51
+ "dim": 5120,
52
+ "n_heads": 32,
53
+ "head_dim": 128,
54
+ },
55
+ }
56
+
57
+ COSMOS_YARN_CONFIG = {
58
+ "original_latent_shape": [3, 40, 64],
59
+ "apply_yarn": True,
60
+ "yarn_beta_fast": 4,
61
+ "yarn_beta_slow": 1,
62
+ "yarn_scale": 2,
63
+ }
64
+
65
+ # Llama3 architecture specifications for different model sizes
66
+ LLAMA3_ARCHITECTURES = {
67
+ "8b": {
68
+ "n_layers": 32,
69
+ "dim": 4096,
70
+ "n_heads": 32,
71
+ "ffn_hidden_size": 14336,
72
+ },
73
+ }
74
+ # Llama3.1 uses YaRN for long context support (context of 128k tokens)
75
+ LLAMA_YARN_CONFIG = {
76
+ "apply_yarn": True,
77
+ "yarn_scale": 8,
78
+ "yarn_beta_fast": 4,
79
+ "yarn_beta_slow": 1,
80
+ }
81
+
82
+ # Mistral architecture specifications for different model sizes
83
+ MISTRAL_ARCHITECTURES = {
84
+ "12b": {
85
+ "n_layers": 40,
86
+ "dim": 5120,
87
+ "n_heads": 32,
88
+ "ffn_hidden_size": 14336,
89
+ "head_dim": 128,
90
+ },
91
+ }
92
+
93
+ PIXTRAL_VISION_ARCHITECTURES = {
94
+ "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"},
95
+ }
96
+
97
+
98
+ def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict:
99
+ """
100
+ Get the model architecture specifications for the given model size, model family and pretrained status.
101
+
102
+ Args:
103
+ model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc.
104
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral"
105
+ pretrained (bool): Whether to load pretrained weights.
106
+
107
+ Returns:
108
+ dict: A dictionary containing the model architecture specifications.
109
+ """
110
+ arch_specs = copy.deepcopy(BASE_CONFIG)
111
+ model_size = model_size.lower()
112
+ if model_family.startswith("cosmos"):
113
+ arch_specs.update(COSMOS_ARCHITECTURES[model_size])
114
+ elif model_family.startswith("llama"):
115
+ arch_specs.update(LLAMA3_ARCHITECTURES[model_size])
116
+ elif model_family in ["mistral", "pixtral"]:
117
+ arch_specs.update(MISTRAL_ARCHITECTURES[model_size])
118
+ if model_family == "pixtral":
119
+ arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size])
120
+ else:
121
+ raise ValueError(f"Model family {model_family} is not supported.")
122
+
123
+ if pretrained:
124
+ if model_family == "cosmos":
125
+ if model_size == "12b":
126
+ arch_specs.update(COSMOS_YARN_CONFIG)
127
+ log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}")
128
+ else:
129
+ pass
130
+ elif model_family in ["llama", "llama3"]:
131
+ pretrained_specs = {
132
+ "rope_theta": 500000,
133
+ "max_seq_len": 8192,
134
+ "vocab_size": 128256,
135
+ }
136
+ arch_specs.update(pretrained_specs)
137
+ elif model_family == "llama3.1":
138
+ pretrained_specs = {
139
+ "rope_theta": 500000,
140
+ "max_seq_len": 131072,
141
+ "original_seq_len": 8192,
142
+ "vocab_size": 128256,
143
+ **LLAMA_YARN_CONFIG,
144
+ }
145
+ arch_specs.update(pretrained_specs)
146
+ elif model_family == "mistral":
147
+ assert model_size == "12b", "We only support Mistral-Nemo-12B model."
148
+ pretrained_specs = {
149
+ "rope_theta": 1000000,
150
+ "max_seq_len": 128000,
151
+ "vocab_size": 131072,
152
+ }
153
+ arch_specs.update(pretrained_specs)
154
+ elif model_family == "pixtral":
155
+ assert model_size == "12b", "We only support Pixtral 12B model."
156
+ pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072}
157
+ arch_specs.update(pretrained_specs)
158
+ else:
159
+ raise ValueError(f"Model family {model_family} doesn't have a pretrained config.")
160
+
161
+ return arch_specs
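As a concrete example of the lookup above, a non-pretrained Cosmos 4B spec is simply BASE_CONFIG merged with the "4b" architecture entry; a minimal sketch:

specs = get_model_arch_specs(model_size="4b", model_family="cosmos", pretrained=False)
# specs == {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336,
#           "n_layers": 16, "dim": 4096, "n_heads": 32}

# With pretrained=True and the 12b size, the Cosmos family additionally merges COSMOS_YARN_CONFIG.
specs_12b = get_model_arch_specs(model_size="12b", model_family="cosmos", pretrained=True)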
162
+
163
+
164
+ def create_text_model_config(
165
+ model_ckpt_path: str,
166
+ tokenizer_path: str,
167
+ tensor_model_parallel_size: int = 1,
168
+ model_family: str = "mistral",
169
+ model_size: str = "12b",
170
+ is_instruct_model: bool = True,
171
+ max_seq_len: int = None,
172
+ max_batch_size: int = 1,
173
+ rope_dim: str = "1D",
174
+ add_special_tokens: bool = True,
175
+ pytorch_rope_version: str = None,
176
+ ) -> dict:
177
+ """Create a text model for training or inference.
178
+ Args:
179
+ model_ckpt_path (str): Path to the model checkpoint.
180
+ tokenizer_path (str): Path to the tokenizer folder.
181
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
182
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
183
+ model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc.
184
+ is_instruct_model (bool): Whether the model is an instruct model.
185
186
+ max_seq_len (int): Maximum sequence length.
187
+ max_batch_size (int): Maximum batch size.
188
+ rope_dim (str): RoPE dimension. Choices: "1D", "3D".
189
+ add_special_tokens (bool): Whether to add special tokens.
190
+ Returns:
191
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
192
+ """
193
+ # Model size specific parameters
194
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
195
+ if max_seq_len is not None:
196
+ # Override the max_seq_len if provided
197
+ model_arch_specs["max_seq_len"] = max_seq_len
198
+ if pytorch_rope_version is not None:
199
+ model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
200
+ model_config = ModelConfig(
201
+ max_batch_size=max_batch_size,
202
+ precision="bfloat16",
203
+ ckpt_path=model_ckpt_path,
204
+ use_qk_normalization=False,
205
+ tensor_model_parallel_size=tensor_model_parallel_size,
206
+ rope_dim=rope_dim,
207
+ **model_arch_specs,
208
+ )
209
+
210
+ tokenizer_config = TokenizerConfig(
211
+ text_tokenizer=TextTokenizerConfig(
212
+ config=L(TextTokenizer)(
213
+ model_family=model_family,
214
+ is_instruct_model=is_instruct_model,
215
+ local_path=tokenizer_path,
216
+ ),
217
+ data_key="text",
218
+ tokenizer_offset=model_config.vocab_size,
219
+ tokenize_here=False,
220
+ vocab_size=model_config.vocab_size,
221
+ ),
222
+ seq_len=model_config.max_seq_len,
223
+ training_type="text_only",
224
+ add_special_tokens=add_special_tokens,
225
+ )
226
+ return model_config, tokenizer_config
227
+
228
+
229
+ def create_vision_language_model_config(
230
+ model_ckpt_path: str,
231
+ tokenizer_ckpt_path: str,
232
+ tensor_model_parallel_size: int = 1,
233
+ model_family: str = "pixtral",
234
+ model_size: str = "12b",
235
+ is_instruct_model: bool = True,
236
+ max_batch_size: int = 1,
237
+ rope_dim: str = "1D",
238
+ add_special_tokens: bool = True,
239
+ max_seq_len: int = None,
240
+ vision_encoder_in_channels: int = 3,
241
+ fuse_qkv: bool = False,
242
+ pytorch_rope_version: str = None,
243
+ ) -> dict:
244
+ """Create a vision-language model for training or inference.
245
+ Args:
246
+ model_ckpt_path (str): Path to the model checkpoint.
247
+ tokenizer_ckpt_path (str): Path to the tokenizer checkpoint.
248
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
249
+ model_family (str): Model family. Choices: "pixtral".
250
+ model_size (str): Model size. Choices: "12b".
251
+ is_instruct_model (bool): Whether the model is an instruct model.
252
+ rope_dim (str): RoPE dimension. Choices: "1D".
253
+ add_special_tokens (bool): Whether to add special tokens.
254
+ max_seq_len (int): Maximum sequence length.
255
+ vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Defaults to 3; values larger than 3 are supported, e.g., set this to 4 for 4-channel images whose last channel is a binary mask.
256
+ fuse_qkv (bool): Whether to fuse the QKV linear layers.
257
+ Returns:
258
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
259
+ """
260
+ # Model size specific parameters
261
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
262
+ if max_seq_len is not None:
263
+ # Override the max_seq_len if provided
264
+ model_arch_specs["max_seq_len"] = max_seq_len
265
+ if pytorch_rope_version is not None:
266
+ model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
267
+
268
+ model_config = ModelConfig(
269
+ max_batch_size=max_batch_size,
270
+ precision="bfloat16",
271
+ ckpt_path=model_ckpt_path,
272
+ use_qk_normalization=False,
273
+ tensor_model_parallel_size=tensor_model_parallel_size,
274
+ rope_dim=rope_dim,
275
+ vision_encoder_in_channels=vision_encoder_in_channels,
276
+ fuse_qkv=fuse_qkv,
277
+ **model_arch_specs,
278
+ )
279
+ # Vision-language tokenizer
280
+ tokenizer_config = TokenizerConfig(
281
+ text_tokenizer=TextTokenizerConfig(
282
+ config=L(ImageTextTokenizer)(
283
+ model_family=model_family,
284
+ is_instruct_model=is_instruct_model,
285
+ image_processor_path=tokenizer_ckpt_path,
286
+ tokenizer_path=tokenizer_ckpt_path,
287
+ ),
288
+ data_key="image_text_interleaved",
289
+ tokenizer_offset=model_config.vocab_size,
290
+ tokenize_here=False,
291
+ vocab_size=model_config.vocab_size,
292
+ ),
293
+ seq_len=model_config.max_seq_len,
294
+ training_type="image_text_interleaved",
295
+ add_special_tokens=add_special_tokens,
296
+ )
297
+ return model_config, tokenizer_config
298
+
299
+
300
+ def create_video2world_model_config(
301
+ model_ckpt_path: str,
302
+ tokenizer_ckpt_path: str,
303
+ tensor_model_parallel_size: int = 1,
304
+ model_family: str = "cosmos",
305
+ model_size: str = "4b",
306
+ pixel_chunk_duration: int = 9,
307
+ num_video_frames: int = 36,
308
+ compression_ratio: List[int] = [8, 16, 16],
309
+ original_seq_len: int = 8192,
310
+ num_condition_latents_t: int = 1,
311
+ num_tokens_to_ignore: int = -1,
312
+ batch_size: int = 2,
313
+ video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
314
+ rope_dim: str = "3D",
315
+ add_special_tokens: bool = True,
316
+ video_height: int = 384,
317
+ video_width: int = 640,
318
+ use_qk_normalization: bool = True,
319
+ insert_cross_attn: bool = False,
320
+ insert_cross_attn_every_k_layers: int = 1,
321
+ context_dim: int = 1024,
322
+ training_type: str = "video_to_video",
323
+ pad_to_multiple_of: Optional[int] = 64,
324
+ vocab_size: int = 64000,
325
+ apply_abs_pos_emb: bool = False,
326
+ ) -> dict:
327
+ """Create a video-to-world model config.
328
+ Args:
329
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
330
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
331
+ model_size (str): Model size. Choices: "1b", "8b", "3b".
332
+ pixel_chunk_duration (int): Number of frames in each chunk.
333
+ num_video_frames (int): Number of video frames.
334
+ compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
335
+ original_seq_len (int): Original sequence length.
336
+ apply_yarn (bool): Whether to apply YaRN for long context scaling.
337
+ yarn_beta_fast (Optional[int]): Fast beta for YaRN.
338
+ yarn_beta_slow (Optional[int]): Slow beta for YaRN.
339
+ yarn_scale (Optional[int]): Scale factor for ctx extension.
340
+ use_qk_normalization (bool): Whether to use Query-Key normalization.
341
+ training_type (str): Type of training task.
342
+ batch_size (int): Batch size.
343
+ video_tokenizer_config_creator (Callable): Callable that builds the video tokenizer config from the tokenizer checkpoint path and pixel_chunk_duration (and, optionally, compression_ratio).
344
+ video_tokenizer_version (str): Version of the video tokenizer.
345
+ num_condition_latents_t (int): Number of conditioning latent channels
346
+ num_tokens_to_ignore (int): Number of tokens to ignore when computing the loss. If non-negative, it takes precedence over num_condition_latents_t.
347
+ video_height (int): Height of the video frame. Defaults to 384.
348
+ video_width (int): Width of the video frame. Defaults to 640.
349
+ rope_dim (str): RoPE dimension. Choices: "1D", "3D".
350
+ add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
351
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
352
+ vocab_size (int): Vocabulary size.
353
+ apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings.
354
+ Returns:
355
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
356
+ """
357
+ assert (
358
+ pixel_chunk_duration % compression_ratio[0] == 1
359
+ ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
360
+ latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
361
+ latent_height = video_height // compression_ratio[1]
362
+ latent_width = video_width // compression_ratio[2]
363
+ # Do some math to compute the video latent shape and sequence length
364
+ assert (
365
+ num_video_frames % pixel_chunk_duration == 0
366
+ ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
367
+ video_latent_shape = [
368
+ num_video_frames // pixel_chunk_duration * latent_chunk_duration,
369
+ latent_height,
370
+ latent_width,
371
+ ]
372
+ # product of video_latent_shape
373
+ num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
374
+ if add_special_tokens:
375
+ seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
376
+ seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
377
+ # for text to video, we need to add <bov> token to indicate the start of the video
378
+ elif training_type == "text_to_video":
379
+ seq_len = num_token_video_latent + 1
380
+ else:
381
+ seq_len = num_token_video_latent
382
+
383
+ if seq_len % pad_to_multiple_of != 0:
384
+ # Round up to the nearest multiple of pad_to_multiple_of
385
+ seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
386
+
387
+ # Model size specific parameters
388
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
389
+
390
+ # Whether to skip the loss for the first chunk; note the first token is already skipped when computing the loss.
391
+ # If num_tokens_to_ignore is specified, use it.
392
+ # Else compute it from num_condition_latents_t
393
+ if num_tokens_to_ignore < 0:
394
+ num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
395
+ if not add_special_tokens and num_condition_latents_t > 0:
396
+ # If there are no special tokens (bov), do a -1 so that you can compute the loss
397
+ # from the first token of the next chunk
398
+ num_tokens_to_ignore -= 1
399
+
400
+ model_config = ModelConfig(
401
+ video_height=video_height,
402
+ video_width=video_width,
403
+ max_seq_len=seq_len,
404
+ max_batch_size=batch_size,
405
+ precision="bfloat16",
406
+ ckpt_path=model_ckpt_path,
407
+ use_qk_normalization=use_qk_normalization,
408
+ vocab_size=64000,
409
+ original_seq_len=original_seq_len,
410
+ tensor_model_parallel_size=tensor_model_parallel_size,
411
+ video_latent_shape=video_latent_shape,
412
+ num_video_frames=num_video_frames,
413
+ rope_dim=rope_dim,
414
+ pad_to_multiple_of=pad_to_multiple_of,
415
+ insert_cross_attn=insert_cross_attn,
416
+ insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
417
+ context_dim=context_dim,
418
+ apply_abs_pos_emb=apply_abs_pos_emb,
419
+ **model_arch_specs,
420
+ )
421
+
422
+ video_tokenizer_config = video_tokenizer_config_creator(
423
+ tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio
424
+ )
425
+ tokenizer_config = TokenizerConfig(
426
+ text_tokenizer=None,
427
+ video_tokenizer=VideoTokenizerConfig(
428
+ config=video_tokenizer_config,
429
+ data_key="video",
430
+ tokenizer_offset=0, # There are no text embeddings in this model. Note this only applies when the model is trained from scratch; if a text-pretrained model is used, the offset is the text tokenizer's vocab size.
431
+ tokenize_here=True,
432
+ max_seq_len=num_token_video_latent,
433
+ vocab_size=vocab_size,
434
+ ),
435
+ seq_len=seq_len,
436
+ training_type=training_type,
437
+ add_special_tokens=add_special_tokens,
438
+ pad_to_multiple_of=pad_to_multiple_of,
439
+ )
440
+ return model_config, tokenizer_config
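For the default arguments above (pixel_chunk_duration=9, num_video_frames=36, compression_ratio=[8, 16, 16], a 384x640 video, add_special_tokens=True), the latent shape and sequence length work out as follows; a self-contained sketch of the same arithmetic:

pixel_chunk_duration, num_video_frames = 9, 36
compression_ratio = [8, 16, 16]
video_height, video_width = 384, 640

latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1  # 2
latent_height = video_height // compression_ratio[1]                            # 24
latent_width = video_width // compression_ratio[2]                              # 40
latent_t = num_video_frames // pixel_chunk_duration * latent_chunk_duration     # 4 chunks * 2 = 8
num_token_video_latent = latent_t * latent_height * latent_width                # 7680 video tokens

seq_len = num_token_video_latent + 3   # room for special tokens -> 7683
seq_len = (seq_len + 63) // 64 * 64    # round up to a multiple of 64 -> 7744
assert seq_len == 7744                 # already a multiple of pad_to_multiple_of=64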
441
+
442
+
443
+ def create_video2world_model(
444
+ tensor_model_parallel_size: int = 1,
445
+ context_parallel_size: int = 1,
446
+ shard_checkpoint: bool = False,
447
+ model_family: str = "cosmos",
448
+ model_size: str = "1b",
449
+ backend: str = "pytorch",
450
+ pixel_chunk_duration: int = 9,
451
+ num_video_frames: int = 36,
452
+ compression_ratio: List[int] = [8, 16, 16],
453
+ original_seq_len: int = 8192,
454
+ apply_yarn: bool = False,
455
+ yarn_beta_fast: Optional[int] = None,
456
+ yarn_beta_slow: Optional[int] = None,
457
+ yarn_scale: Optional[int] = None,
458
+ num_condition_latents_t: int = 1,
459
+ num_tokens_to_ignore: int = -1,
460
+ batch_size: int = 1,
461
+ fsdp_enabled: bool = False,
462
+ act_ckpt_enabled: bool = False,
463
+ video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
464
+ rope_dim: str = "3D",
465
+ add_special_tokens: bool = False,
466
+ video_height: int = 384,
467
+ video_width: int = 640,
468
+ original_latent_shape: Optional[List[int]] = None,
469
+ use_qk_normalization: bool = True,
470
+ sequence_parallel: bool = False,
471
+ insert_cross_attn: bool = False,
472
+ insert_cross_attn_every_k_layers: int = 1,
473
+ context_dim: int = 1024,
474
+ finetune_layers_with_cross_attn: bool = False,
475
+ finetune_layers_without_cross_attn: bool = False,
476
+ use_action_condition: bool = False,
477
+ action_embedding_mode: Optional[str] = "mlp",
478
+ action_dim: int = 8, # ACTION_DIM,
479
+ action_embedding_dim: int = 1024,
480
+ group_causal_mask_mode: Optional[str] = None,
481
+ training_type: str = "video_to_video",
482
+ pad_to_multiple_of: Optional[int] = 1,
483
+ z_loss_coeff: float = 1e-4,
484
+ temporal_overlap: int = 0,
485
+ embedding_dropout: float = 0.0,
486
+ insert_medusa_head: bool = False,
487
+ ft_medusa_option: str = "fft",
488
+ medusa_num_heads: int = 7,
489
+ medusa_num_layers: int = 1,
490
+ medusa_concat_heads: bool = True,
491
+ fuse_qkv: bool = False,
492
+ zero_init_cross_attn_proj: bool = False,
493
+ concat_action_to_context: bool = False,
494
+ tokenizer_ckpt_path: str = "checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit",
495
+ ) -> dict:
496
+ """Create a video-to-video model for training.
497
+ Args:
498
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
499
+ context_parallel_size (int): Number of context parallel groups.
500
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
501
+ model_size (str): Model size. Choices: "1b", "8b", "3b".
502
+ backend (str): Backend for the model. Choices: "pytorch", "transformer_engine".
503
+ pixel_chunk_duration (int): Number of frames in each chunk.
504
+ num_video_frames (int): Number of video frames.
505
+ compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
506
+ original_seq_len (int): Original sequence length.
507
+ apply_yarn (bool): Whether to apply YaRN for long context scaling.
508
+ yarn_beta_fast (Optional[int]): Fast beta for YaRN.
509
+ yarn_beta_slow (Optional[int]): Slow beta for YaRN.
510
+ yarn_scale (Optional[int]): Scale factor for ctx extension.
511
+ fsdp_enabled (bool): Whether Fully Sharded Data Parallel (FSDP) is enabled.
512
+ act_ckpt_enabled (bool): Whether activation checkpointing is enabled.
513
+ use_qk_normalization (bool): Whether to use Query-Key normalization.
514
+ training_type (str): Type of training task.
515
+ batch_size (int): Batch size.
516
+ video_tokenizer_config_creator (Callable): Callable that builds the video tokenizer config from the tokenizer checkpoint path and pixel_chunk_duration (and, optionally, compression_ratio).
517
+ video_tokenizer_version (str): Version of the video tokenizer.
518
+ num_condition_latents_t (int): Number of conditioning latent channels
519
+ num_tokens_to_ignore (int): Number of tokens to ignore when computing the loss. If non-negative, it takes precedence over num_condition_latents_t.
520
+ video_height (int): Height of the video frame. Defaults to 384.
521
+ video_width (int): Width of the video frame. Defaults to 640.
522
+ rope_dim (str): RoPE dimension. Choices: "1D", "2D", "3D".
523
+ add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
524
+ original_latent_shape (list): Original latent shape before RoPE scaling.
525
+ sequence_parallel (bool): Whether to enable sequence parallelism.
526
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
527
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
528
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
529
+ finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
530
+ finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
531
+ use_action_condition (bool): Whether to use action condition.
532
+ action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
533
+ action_dim (int): Dimension of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
534
+ action_embedding_dim (int): Dimension of the action embedding.
535
+ group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
536
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
537
+ z_loss_coeff (float): Coefficient for the z loss.
538
+ temporal_overlap (int): Temporal overlap in the latent space.
539
+ embedding_dropout (float): Dropout rate for the embeddings.
540
+ insert_medusa_head (bool): Whether to insert the Medusa head.
541
+ ft_medusa_option (str): Options on which layers to finetune, choices like:
542
+ "fft": fully fine-tune both medusa heads and all LLM backbone;
543
+ "head": fine-tune medusa heads;
544
+ "head_out": fine-tune medusa heads, and the output layer;
545
+ "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
546
+ medusa_num_heads (int): Number of heads in the Medusa head.
547
+ medusa_num_layers (int): Number of layers in the Medusa head.
548
+ medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
549
+ fuse_qkv (bool): Whether to fuse the QKV linear layers.
550
+ zero_init_cross_attn_proj (bool): Whether to zero-initialize the cross-attention projection weights (default False).
551
+ concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
552
+ Returns:
553
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
554
+ """
555
+ assert (
556
+ pixel_chunk_duration % compression_ratio[0] == 1
557
+ ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
558
+ latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
559
+ latent_height = video_height // compression_ratio[1]
560
+ latent_width = video_width // compression_ratio[2]
561
+ # Compute the video latent shape and sequence length
562
+ if temporal_overlap == 0:
563
+ assert (
564
+ num_video_frames % pixel_chunk_duration == 0
565
+ ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
566
+ video_latent_shape = [
567
+ num_video_frames // pixel_chunk_duration * latent_chunk_duration,
568
+ latent_height,
569
+ latent_width,
570
+ ]
571
+
572
+ else:
573
+ # Calculate temporal overlap in the latent space
574
+ temporal_overlap_latent = temporal_overlap // compression_ratio[0]
575
+
576
+ # Calculate the effective number of latent chunks for the video
577
+ latent_chunks = (num_video_frames - temporal_overlap) // (pixel_chunk_duration - temporal_overlap)
578
+
579
+ # Compute the total duration of the latent chunks, accounting for overlap
580
+ effective_latent_duration = (
581
+ latent_chunk_duration - temporal_overlap_latent
582
+ ) * latent_chunks + temporal_overlap_latent
583
+
584
+ # Define the shape of the video in the latent space
585
+ video_latent_shape = [
586
+ effective_latent_duration, # Temporal dimension
587
+ latent_height, # Height in the latent space
588
+ latent_width, # Width in the latent space
589
+ ]
590
+
591
+ # product of video_latent_shape
592
+ num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
593
+ if add_special_tokens:
594
+ seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
595
+ seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
596
+ # for text to video, we need to add <bov> token to indicate the start of the video
597
+ elif training_type == "text_to_video":
598
+ seq_len = num_token_video_latent + 1
599
+ else:
600
+ seq_len = num_token_video_latent
601
+
602
+ if seq_len % pad_to_multiple_of != 0:
603
+ # Round up to the nearest multiple of pad_to_multiple_of
604
+ seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
605
+
606
+ # Model size specific parameters
607
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=False)
608
+
609
+ inference = False # False for training, True for inference
610
+ # set_parallel_mode = True
611
+ set_parallel_mode = tensor_model_parallel_size > 1
612
+ attention_tp = True
613
+
614
+ if context_parallel_size > 1:
615
+ assert backend == "transformer_engine", "Context parallelism is only supported in transformer engine."
616
+
617
+ if tensor_model_parallel_size > 1:
618
+ assert set_parallel_mode, "Tensor model parallelism is only supported in parallel mode."
619
+
620
+ # Whether to skip the loss for the first chunk; note the first token is already skipped when computing the loss.
621
+ # If num_tokens_to_ignore is specified, use it.
622
+ # Else compute it from num_condition_latents_t
623
+ if num_tokens_to_ignore < 0:
624
+ num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
625
+ if not add_special_tokens and num_condition_latents_t > 0:
626
+ # If there are no special tokens (bov), do a -1 so that you can compute the loss
627
+ # from the first token of the next chunk
628
+ num_tokens_to_ignore -= 1
629
+
630
+ model_config = TrainingModelConfig(
631
+ video_height=video_height,
632
+ video_width=video_width,
633
+ max_seq_len=seq_len,
634
+ max_batch_size=batch_size,
635
+ inference=inference,
636
+ backend=backend,
637
+ precision="bfloat16",
638
+ ema=EMAConfig(enabled=False),
639
+ act_ckpt_enabled=act_ckpt_enabled,
640
+ fsdp_enabled=fsdp_enabled,
641
+ cache_dir=None,
642
+ ckpt_path="checkpoints/Cosmos-Predict1-4B/model.pt",
643
+ use_qk_normalization=use_qk_normalization,
644
+ vocab_size=64000,
645
+ ignore_first_num_tokens=num_tokens_to_ignore,
646
+ apply_yarn=apply_yarn,
647
+ yarn_beta_fast=yarn_beta_fast,
648
+ yarn_beta_slow=yarn_beta_slow,
649
+ original_seq_len=original_seq_len,
650
+ yarn_scale=yarn_scale,
651
+ context_parallel_size=context_parallel_size,
652
+ tensor_model_parallel_size=tensor_model_parallel_size,
653
+ set_parallel_mode=set_parallel_mode,
654
+ attention_tp=attention_tp,
655
+ video_latent_shape=video_latent_shape,
656
+ num_video_frames=num_video_frames,
657
+ rope_dim=rope_dim,
658
+ original_latent_shape=original_latent_shape,
659
+ pad_to_multiple_of=pad_to_multiple_of,
660
+ sequence_parallel=sequence_parallel,
661
+ insert_cross_attn=insert_cross_attn,
662
+ insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
663
+ context_dim=context_dim,
664
+ finetune_layers_with_cross_attn=finetune_layers_with_cross_attn,
665
+ finetune_layers_without_cross_attn=finetune_layers_without_cross_attn,
666
+ use_action_condition=use_action_condition,
667
+ action_embedding_mode=action_embedding_mode,
668
+ action_dim=action_dim,
669
+ action_embedding_dim=action_embedding_dim,
670
+ group_causal_mask_mode=group_causal_mask_mode,
671
+ z_loss_coeff=z_loss_coeff,
672
+ embedding_dropout=embedding_dropout,
673
+ insert_medusa_head=insert_medusa_head,
674
+ ft_medusa_option=ft_medusa_option,
675
+ medusa_num_heads=medusa_num_heads,
676
+ medusa_num_layers=medusa_num_layers,
677
+ medusa_concat_heads=medusa_concat_heads,
678
+ fuse_qkv=fuse_qkv,
679
+ zero_init_cross_attn_proj=zero_init_cross_attn_proj,
680
+ concat_action_to_context=concat_action_to_context,
681
+ **model_arch_specs,
682
+ )
683
+
684
+ tokenizer_config = TokenizerConfig(
685
+ text_tokenizer=None,
686
+ video_tokenizer=VideoTokenizerConfig(
687
+ config=video_tokenizer_config_creator(
688
+ ckpt_path=tokenizer_ckpt_path, pixel_chunk_duration=pixel_chunk_duration
689
+ ),
690
+ data_key="video",
691
+ tokenizer_offset=0,
692
+ vocab_size=64000,
693
+ tokenize_here=True,
694
+ max_seq_len=num_token_video_latent,
695
+ temporal_overlap=temporal_overlap,
696
+ ),
697
+ seq_len="${model.model_config.max_seq_len}",
698
+ training_type=training_type,
699
+ add_special_tokens=add_special_tokens,
700
+ pad_to_multiple_of=pad_to_multiple_of,
701
+ )
702
+
703
+ model_parallel = ModelParallelConfig(
704
+ bf16=True,
705
+ params_dtype=getattr(torch, "bfloat16"),
706
+ )
707
+ model_parallel.tensor_model_parallel_size = "${model.model_config.tensor_model_parallel_size}"
708
+ model_parallel.context_parallel_size = "${model.model_config.context_parallel_size}"
709
+ model_parallel.sequence_parallel = "${model.model_config.sequence_parallel}"
710
+ return L(AutoRegressiveTrainingModel.build)(
711
+ seed=0,
712
+ train_from_scratch=True,
713
+ model_config=model_config,
714
+ fsdp_checkpointer=None,
715
+ tokenizer_config=tokenizer_config,
716
+ model_parallel=model_parallel,
717
+ shard_checkpoint=shard_checkpoint,
718
+ )
cosmos_predict1/autoregressive/configs/base/model_parallel.py ADDED
@@ -0,0 +1,33 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from megatron.core import ModelParallelConfig
18
+
19
+ from cosmos_predict1.utils.lazy_config import LazyDict
20
+
21
+
22
+ def create_model_parallel_config():
23
+ model_parallel = ModelParallelConfig(bf16=True, params_dtype=getattr(torch, "bfloat16"))
24
+ model_parallel.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}"
25
+ model_parallel.context_parallel_size = "${model.model_parallel.context_parallel_size}"
26
+ model_parallel.sequence_parallel = "${model.model_parallel.sequence_parallel}"
27
+ MODEL_PARALLELS = LazyDict(
28
+ dict(
29
+ model_parallel_bf16=model_parallel,
30
+ ),
31
+ flags={"allow_objects": True},
32
+ )
33
+ return MODEL_PARALLELS["model_parallel_bf16"]
cosmos_predict1/autoregressive/configs/base/optim.py ADDED
@@ -0,0 +1,86 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
19
+
20
+
21
+ class LambdaLinearWarmupScheduler:
22
+ """
23
+ A learning rate scheduler that implements linear warm-up and cool-down.
24
+
25
+ This scheduler provides three phases:
26
+ 1. Warm-up: Learning rate linearly increases from 0 to 1.
27
+ 2. Constant: Learning rate remains at 1.
28
+ 3. Cool-down: Learning rate linearly decreases from 1 to 0.
29
+
30
+ Args:
31
+ warmup_steps (int): Number of steps for the warm-up phase.
32
+ warmup_offset (int): Starts warmup from this offset.
33
+ max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided.
34
+ cooldown_steps (int, optional): Number of steps for the cool-down phase.
35
+
36
+ Raises:
37
+ ValueError: If cooldown_steps is provided without max_iter, or if an invalid step is given.
38
+ """
39
+
40
+ def __init__(self, warmup_steps: int, warmup_offset: int = 0, max_iter: int = None, cooldown_steps: int = None):
41
+ self.warmup_steps = warmup_steps
42
+ self.warmup_offset = warmup_offset
43
+ self.max_iter = max_iter
44
+ self.cooldown_steps = cooldown_steps
45
+
46
+ if cooldown_steps is not None:
47
+ if max_iter is None:
48
+ raise ValueError("max_iter must be specified when cooldown_steps is provided")
49
+ self.cooldown_start = max_iter - cooldown_steps
50
+ else:
51
+ self.cooldown_start = None
52
+
53
+ def __call__(self, step):
54
+ # Warm-up phase
55
+ if step < self.warmup_offset:
56
+ return 0
57
+
58
+ if step < self.warmup_steps + self.warmup_offset:
59
+ return float(step - self.warmup_offset) / float(max(1, self.warmup_steps))
60
+
61
+ # Constant phase (no cool-down)
62
+ elif self.cooldown_steps is None:
63
+ return 1.0
64
+
65
+ # Constant phase (before cool-down starts)
66
+ elif step < self.cooldown_start:
67
+ return 1.0
68
+
69
+ # Cool-down phase
70
+ elif self.cooldown_start <= step < self.max_iter:
71
+ cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps
72
+ return 1.0 - cooldown_progress
73
+
74
+ # After max_iter
75
+ elif step >= self.max_iter:
76
+ return 0.0
77
+
78
+ # Unexpected case
79
+ else:
80
+ raise ValueError(f"Invalid step {step}")
81
+
82
+
83
+ LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)(
84
+ optimizer=None,
85
+ lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000),
86
+ )
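A minimal sketch of the multiplier this schedule produces, assuming the package is importable; the same callable is what LambdaLinearLR above hands to torch.optim.lr_scheduler.LambdaLR:

from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearWarmupScheduler

sched = LambdaLinearWarmupScheduler(warmup_steps=4, max_iter=10, cooldown_steps=2)
multipliers = [sched(step) for step in range(11)]
# -> [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0]
#    warm-up over steps 0-3, constant through step 8, linear cool-down, and 0.0 at/after max_iter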
cosmos_predict1/autoregressive/configs/base/tokenizer.py ADDED
@@ -0,0 +1,139 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer
21
+ from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer
22
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
23
+ from cosmos_predict1.utils.lazy_config import LazyDict
24
+
25
+
26
+ def create_discrete_video_fsq_tokenizer_state_dict_config(
27
+ ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16]
28
+ ) -> LazyDict:
29
+ CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
30
+ # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime.
31
+ # - It relies on fully 3D discrete wavelet transform
32
+ # - Uses a layer norm instead of a group norm
33
+ # - Factorizes full convolutions into spatial and temporal convolutions
34
+ # - Factorizes full attention into spatial and temporal attention
35
+ # - Strictly causal, with flexible temporal length at inference.
36
+ attn_resolutions=[32],
37
+ channels=128,
38
+ channels_mult=[2, 4, 4],
39
+ dropout=0.0,
40
+ in_channels=3,
41
+ num_res_blocks=2,
42
+ out_channels=3,
43
+ resolution=1024,
44
+ patch_size=4,
45
+ patch_method="haar",
46
+ z_channels=16,
47
+ z_factor=1,
48
+ num_groups=1,
49
+ legacy_mode=False,
50
+ spatial_compression=16,
51
+ temporal_compression=8,
52
+ embedding_dim=6,
53
+ levels=[8, 8, 8, 5, 5, 5],
54
+ name="CausalDiscreteFactorizedVideoTokenizer",
55
+ )
56
+
57
+ return L(DiscreteVideoFSQStateDictTokenizer)(
58
+ enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
59
+ dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
60
+ tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
61
+ name="discrete_video_fsq",
62
+ latent_ch=6,
63
+ is_bf16=True,
64
+ pixel_chunk_duration=pixel_chunk_duration,
65
+ latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
66
+ max_enc_batch_size=8,
67
+ max_dec_batch_size=4,
68
+ levels=[8, 8, 8, 5, 5, 5],
69
+ compression_ratio=compression_ratio,
70
+ )
71
+
72
+
73
+ @attrs.define(slots=False)
74
+ class TextTokenizerConfig:
75
+ """
76
+ Text tokenizer config
77
+
78
+ Args:
79
+ config: Config file to define the text tokenizer class.
80
+ data_key (str): The input key from data_dict that will be passed to the text tokenizer.
81
+ tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
82
+ tokenizer_offset (int): Offset that is added to the tokens.
83
+ vocab_size (int): Vocabulary size of the tokenizer.
84
+ """
85
+
86
+ config: LazyDict
87
+ data_key: str = ""
88
+ tokenize_here: bool = False
89
+ tokenizer_offset: int = 0
90
+ vocab_size: int = 0
91
+
92
+
93
+ @attrs.define(slots=False)
94
+ class VideoTokenizerConfig:
95
+ """
96
+ Video tokenizer config
97
+
98
+ Args:
99
+ config: Config file to define the video tokenizer class.
100
+ data_key (str): The input key from data_dict that will be passed to the video tokenizer.
101
+ tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
102
+ tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
103
+ add an offset to make sure that video tokens and text tokens don't overlap.
104
+ vocab_size (int): Vocabulary size of the tokenizer.
105
+ max_seq_len (int): Maximum token length for an input video.
106
+ temporal_overlap (int): Overlap between consecutive video chunks.
107
+ """
108
+
109
+ config: LazyDict
110
+ data_key: str = ""
111
+ tokenize_here: bool = True
112
+ tokenizer_offset: int = 0
113
+ vocab_size: int = 0
114
+ max_seq_len: int = -1
115
+ temporal_overlap: int = 0
116
+
117
+
118
+ @attrs.define(slots=False)
119
+ class TokenizerConfig:
120
+ """
121
+ Joint tokenizer config
122
+
123
+ Args:
124
+ text_tokenizer (TextTokenizerConfig): Text tokenizer config file
125
+ class_tokenizer (ClassTokenizerConfig): Class tokenizer config file
126
+ video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
127
+ image_tokenizer (ImageTokenizerConfig): Image tokenizer config file
128
+ seq_len (int): Final token sequence length
129
+ training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
130
+ add_special_tokens (bool): Whether to add special tokens to the output tokens
131
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
132
+ """
133
+
134
+ text_tokenizer: Optional[TextTokenizerConfig] = None
135
+ video_tokenizer: Optional[VideoTokenizerConfig] = None
136
+ seq_len: int = 4096
137
+ training_type: str = None
138
+ add_special_tokens: bool = True
139
+ pad_to_multiple_of: Optional[int] = 64
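A minimal sketch of assembling a video-only TokenizerConfig from the pieces in this file, mirroring what create_video2world_model_config does (the checkpoint path is a placeholder):

video_tok_cfg = create_discrete_video_fsq_tokenizer_state_dict_config(
    ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",  # placeholder path
    pixel_chunk_duration=33,
    compression_ratio=[8, 16, 16],
)
tokenizer_config = TokenizerConfig(
    text_tokenizer=None,
    video_tokenizer=VideoTokenizerConfig(
        config=video_tok_cfg,
        data_key="video",
        tokenize_here=True,
        vocab_size=64000,
    ),
    seq_len=8192,
    training_type="video_to_video",
    add_special_tokens=False,
)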
cosmos_predict1/autoregressive/configs/config.py ADDED
@@ -0,0 +1,111 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Default config for cosmos_ar project."""
17
+
18
+ import os
19
+ from typing import Any, List
20
+
21
+ import attrs
22
+
23
+ from cosmos_predict1.autoregressive.configs.registry import register_configs
24
+ from cosmos_predict1.autoregressive.trainer import Trainer
25
+ from cosmos_predict1.utils import config, log
26
+ from cosmos_predict1.utils.config_helper import import_all_modules_from_package
27
+
28
+
29
+ @attrs.define(slots=False)
30
+ class Config(config.Config):
31
+ defaults: List[Any] = attrs.field(
32
+ factory=lambda: [
33
+ "_self_",
34
+ {"model": None},
35
+ {"data_train": "mock_video"},
36
+ {"data_val": None},
37
+ {"optimizer": "fused_adamw"},
38
+ {"scheduler": "warmup_cosine_lr"},
39
+ {"checkpoint": "local"},
40
+ {"callbacks": "basic"},
41
+ {"global_config": None},
42
+ {"experiment": None},
43
+ ]
44
+ )
45
+
46
+ def validate(self) -> None:
47
+ """Validate that the config has all required fields."""
48
+ assert self.job.project != "", "job.project is not set"
49
+ assert self.job.group != "", "job.group is not set"
50
+ assert self.job.name != "", "job.name is not set"
51
+ log.info("Validating config for cosmos_autoregressive job")
52
+ # FSDP config check
53
+ if self.model.model_config.fsdp_enabled:
54
+ assert self.trainer.distributed_parallelism == "fsdp"
55
+ else:
56
+ assert self.trainer.distributed_parallelism == "ddp"
57
+
58
+ # Transformer Engine config check
59
+ if self.model.model_config.backend == "transformer_engine":
60
+ assert (
61
+ "NVTE_FLASH_ATTN" in os.environ and os.environ["NVTE_FLASH_ATTN"] == "1"
62
+ ) # Enable Flash attention for transformer engine
63
+
64
+ # TP, CP config check
65
+ if self.model_parallel is not None:
66
+ if self.model_parallel.context_parallel_size > 1:
67
+ assert (
68
+ self.model.model_config.backend == "transformer_engine"
69
+ ), "Context parallelism is only supported in transformer engine."
70
+
71
+ if self.model_parallel.tensor_model_parallel_size > 1:
72
+ assert (
73
+ self.model.model_config.set_parallel_mode
74
+ ), "Tensor model parallelism is only supported in parallel mode."
75
+
76
+ if self.model_parallel.sequence_parallel:
77
+ assert (
78
+ self.model_parallel.tensor_model_parallel_size > 1
79
+ ), "Sequence parallelism is only supported in tensor model parallelism."
80
+ assert (
81
+ self.model.model_config.backend == "transformer_engine"
82
+ ), "Sequence parallelism is only supported in transformer engine."
83
+
84
+
85
+ def make_config():
86
+ c = Config(
87
+ model=None,
88
+ optimizer=None,
89
+ scheduler=None,
90
+ dataloader_train=None,
91
+ dataloader_val=None,
92
+ checkpoint=None,
93
+ )
94
+
95
+ c.job.project = "cosmos_autoregressive"
96
+ c.job.group = "debug"
97
+ c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}"
98
+
99
+ c.trainer.type = Trainer
100
+ c.trainer.run_validation = True
101
+
102
+ c.trainer.seed = 0
103
+ c.trainer.max_iter = 10
104
+ c.trainer.logging_iter = 1
105
+
106
+ c.trainer.callbacks = None
107
+ register_configs()
108
+ # experiment config are defined in the experiment folder
109
+ # call import_all_modules_from_package to register them
110
+ import_all_modules_from_package("cosmos_predict1.autoregressive.configs.experiment")
111
+ return c
cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py ADDED
File without changes
cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py ADDED
@@ -0,0 +1,163 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ This file contains a basic configuration for video2video experiments.
18
+ """
19
+
20
+ from hydra.core.config_store import ConfigStore
21
+
22
+ from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model
23
+ from cosmos_predict1.autoregressive.configs.base.model_parallel import create_model_parallel_config
24
+ from cosmos_predict1.utils import log
25
+ from cosmos_predict1.utils.lazy_config import LazyDict
26
+
27
+ cs = ConfigStore.instance()
28
+
29
+
30
+ """
31
+ Finetune 4B model with TP=1, pytorch backend, low resolution tealrobot data, frames 33, chunk 33.
32
+ Usage:
33
+ torchrun --nproc_per_node=1 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobotsmall_tp1
34
+ """
35
+ base_4b_example_tealrobotsmall_tp1: LazyDict = LazyDict(
36
+ dict(
37
+ defaults=[
38
+ {"override /data_train": "tealrobot_video_small"},
39
+ {
40
+ "override /callbacks": [
41
+ "basic",
42
+ "video_teacher_forcing",
43
+ ]
44
+ },
45
+ {"override /checkpoint": "local"},
46
+ {"override /optimizer": "fused_adamw"},
47
+ {"override /scheduler": "warmup_cosine_lr"},
48
+ "_self_",
49
+ ],
50
+ job=dict(
51
+ project="posttraining",
52
+ group="autoregressive_base",
53
+ name="base_4b_example_tealrobotsmall_tp1",
54
+ ),
55
+ model=create_video2world_model(
56
+ model_size="4b",
57
+ model_family="cosmos",
58
+ backend="pytorch",
59
+ tensor_model_parallel_size=1,
60
+ batch_size=1,
61
+ pixel_chunk_duration=33,
62
+ num_video_frames=33,
63
+ video_height=384,
64
+ video_width=640,
65
+ tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
66
+ add_special_tokens=False,
67
+ ),
68
+ trainer=dict(
69
+ max_iter=50000,
70
+ grad_accum_iter=1,
71
+ grad_scaler_args=dict(enabled=False),
72
+ run_validation=False, # No need for validation as epoch <= 1
73
+ distributed_parallelism="ddp",
74
+ callbacks=dict(
75
+ vid_sampling_tf=dict(
76
+ every_n=500,
77
+ ),
78
+ ),
79
+ ),
80
+ checkpoint=dict(
81
+ load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
82
+ load_training_state=False,
83
+ strict_resume=True,
84
+ save_iter=1000,
85
+ ),
86
+ model_parallel=create_model_parallel_config(),
87
+ ),
88
+ )
89
+
90
+
91
+ """
92
+ Finetune 4B model with TP=4, pytorch backend, high resolution tealrobot data, frame 33, chunk 33.
93
+ Usage:
94
+ torchrun --nproc_per_node=4 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobot_tp4
95
+ """
96
+ base_4b_example_tealrobot_tp4: LazyDict = LazyDict(
97
+ dict(
98
+ defaults=[
99
+ {"override /data_train": "tealrobot_video"},
100
+ {
101
+ "override /callbacks": [
102
+ "basic",
103
+ "video_teacher_forcing",
104
+ ]
105
+ },
106
+ {"override /checkpoint": "local"},
107
+ {"override /optimizer": "fused_adamw"},
108
+ {"override /scheduler": "warmup_cosine_lr"},
109
+ "_self_",
110
+ ],
111
+ job=dict(
112
+ project="posttraining",
113
+ group="autoregressive_base",
114
+ name="base_4b_example_tealrobot_tp4",
115
+ ),
116
+ model=create_video2world_model(
117
+ model_size="4b",
118
+ model_family="cosmos",
119
+ backend="pytorch",
120
+ tensor_model_parallel_size=4,
121
+ batch_size=1,
122
+ pixel_chunk_duration=33,
123
+ num_video_frames=33,
124
+ video_height=640,
125
+ video_width=848,
126
+ tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
127
+ add_special_tokens=False,
128
+ ),
129
+ trainer=dict(
130
+ max_iter=50000,
131
+ grad_accum_iter=1,
132
+ grad_scaler_args=dict(enabled=False),
133
+ run_validation=False, # No need for validation as epoch <= 1
134
+ distributed_parallelism="ddp",
135
+ callbacks=dict(
136
+ vid_sampling_tf=dict(
137
+ every_n=500,
138
+ ),
139
+ ),
140
+ ),
141
+ checkpoint=dict(
142
+ load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
143
+ load_training_state=False,
144
+ strict_resume=False,
145
+ save_iter=1000,
146
+ ),
147
+ model_parallel=create_model_parallel_config(),
148
+ ),
149
+ )
150
+
151
+
152
+ def register_experiments(cs):
153
+ # Register the experiments
154
+ for _item in [
155
+ base_4b_example_tealrobotsmall_tp1,
156
+ base_4b_example_tealrobot_tp4,
157
+ ]:
158
+ cs.store(
159
+ group="experiment",
160
+ package="_global_",
161
+ name=_item["job"]["name"],
162
+ node=_item,
163
+ )
cosmos_predict1/autoregressive/configs/inference/inference_config.py ADDED
@@ -0,0 +1,102 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, List, Optional, Union
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TokenizerConfig
21
+
22
+
23
+ @attrs.define(slots=False)
24
+ class DataShapeConfig:
25
+ latent_shape: list = []
26
+ num_video_frames: Union[None, int] = None
27
+ height: Union[None, int] = None
28
+ width: Union[None, int] = None
29
+
30
+
31
+ @attrs.define(slots=False)
32
+ class SamplingConfig:
33
+ """
34
+ Sampling config
35
+ Args:
36
+ temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
37
+ top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
38
+ logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False.
39
+ echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
40
+
41
+ """
42
+
43
+ temperature: float = 0.6
44
+ top_k: Optional[int] = None
45
+ top_p: float = 0.9
46
+ compile_prefill: bool = False
47
+ compile_sampling: bool = True
48
+ logprobs: bool = False
49
+ echo: bool = False
50
+
51
+
52
+ @attrs.define(slots=False)
53
+ class DiffusionDecoderSamplingConfig:
54
+ """
55
+ Diffusion decoder sampling config
56
+ Args:
57
+ guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 1.8.
58
+ sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02.
59
+ sigma (float): Initial noise level for the diffusion process. Defaults to 8.
60
+ num_steps (int): Number of denoising steps to perform. Defaults to 15.
61
+ overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2.
62
+ continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16.
63
+ continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8.
64
+ dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57.
65
+ """
66
+
67
+ guidance: float = 1.8
68
+ sigma_min: float = 0.02
69
+ sigma: float = 8
70
+ num_steps: int = 15
71
+ overlap: int = 2
72
+ continuous_tokenizer_channel: int = 16
73
+ continuous_tokenizer_spatial_compression_ratio: int = 8
74
+ dd_train_num_video_frames: int = 57
75
+ max_iter: int = 99
76
+ fps: int = 24
77
+
78
+
79
+ @attrs.define(slots=False)
80
+ class InferenceConfig:
81
+ """
82
+ Inference config
83
+ Args:
84
+ model_config (ModelConfig): Model config
85
+ tokenizer_config (TokenizerConfig): Tokenizer config
86
+ ckpt_path (str): Path to the checkpoint
87
+ data_shape_config (DataShapeConfig): Data shape config (latent shape, number of video frames, height, width)
88
+ """
89
+
90
+ model_config: Optional[ModelConfig] = None
91
+ tokenizer_config: Optional[TokenizerConfig] = None
92
+ ckpt_path: str = ""
93
+ data_shape_config: Optional[DataShapeConfig] = None
94
+
95
+ defaults: List[Any] = attrs.field(
96
+ factory=lambda: [
97
+ "_self_",
98
+ {"data_val": None},
99
+ {"data_shape_config": "video_shape_as_model_config"},
100
+ {"eval_job": None},
101
+ ]
102
+ )
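
Because these are plain attrs classes, they can be constructed directly and selectively overridden; a minimal sketch with illustrative (not recommended) values:

from cosmos_predict1.autoregressive.configs.inference.inference_config import (
    DiffusionDecoderSamplingConfig,
    SamplingConfig,
)

# Nucleus sampling, slightly more diverse than the default temperature of 0.6.
sampling = SamplingConfig(temperature=0.8, top_p=0.95)
# More denoising steps than the default 15, trading speed for quality.
dd_sampling = DiffusionDecoderSamplingConfig(num_steps=35)
print(sampling.temperature, sampling.top_p, dd_sampling.num_steps, dd_sampling.guidance)
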
cosmos_predict1/autoregressive/configs/registry.py ADDED
@@ -0,0 +1,89 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from hydra.core.config_store import ConfigStore
18
+
19
+ from cosmos_predict1.autoregressive.configs.base.callbacks import BASIC_CALLBACKS, VIDEO_TEACHER_FORCING_CALLBACK
20
+ from cosmos_predict1.autoregressive.configs.base.dataloader import get_tealrobot_video
21
+ from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearLR
22
+ from cosmos_predict1.autoregressive.configs.experiment.video2video.basic import register_experiments
23
+ from cosmos_predict1.utils import config, log
24
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
25
+ from cosmos_predict1.utils.scheduler import WarmupCosineLR
26
+
27
+
28
+ def register_checkpoint(cs):
29
+ checkpoint_local = config.CheckpointConfig(
30
+ save_iter=5000,
31
+ broadcast_via_filesystem=True,
32
+ )
33
+ cs.store(group="checkpoint", package="checkpoint", name="local", node=checkpoint_local)
34
+
35
+
36
+ def register_callbacks(cs):
37
+ cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
38
+ cs.store(
39
+ group="callbacks",
40
+ package="trainer.callbacks",
41
+ name="video_teacher_forcing",
42
+ node=VIDEO_TEACHER_FORCING_CALLBACK,
43
+ )
44
+
45
+
46
+ def register_scheduler(cs):
47
+ cs.store(
48
+ group="scheduler",
49
+ package="scheduler",
50
+ name="warmup_cosine_lr",
51
+ node=L(WarmupCosineLR)(optimizer=None, warmup_iters=5000, lr_decay_iters="${trainer.max_iter}", min_lr=1e-8),
52
+ )
53
+ cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearLR)
54
+
55
+
56
+ def register_optimizer(cs):
57
+ cs.store(
58
+ group="optimizer",
59
+ package="optimizer",
60
+ name="fused_adamw",
61
+ node=L(torch.optim.AdamW)(params=None, lr=1e-3, weight_decay=0.05, fused=True),
62
+ )
63
+ cs.store(
64
+ group="optimizer",
65
+ package="optimizer",
66
+ name="sgd",
67
+ node=L(torch.optim.SGD)(params=None, lr=5e-6, momentum=0.9),
68
+ )
69
+
70
+
71
+ def register_training_data(cs):
72
+ cs.store(
73
+ group="data_train",
74
+ package="dataloader_train",
75
+ name="tealrobot_video_small",
76
+ node=get_tealrobot_video(num_frames=33, video_size=[384, 640]),
77
+ )
78
+ cs.store(group="data_train", package="dataloader_train", name="tealrobot_video", node=get_tealrobot_video())
79
+
80
+
81
+ def register_configs():
82
+ log.info("Registering configs for autoregressive_base")
83
+ cs = ConfigStore.instance()
84
+ register_callbacks(cs)
85
+ register_checkpoint(cs)
86
+ register_optimizer(cs)
87
+ register_scheduler(cs)
88
+ register_training_data(cs)
89
+ register_experiments(cs)
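
A hedged sketch of how these registrations are meant to be consumed: register_configs() has to run before Hydra resolves the "override /optimizer: fused_adamw" style defaults used by the experiments above (the real wiring lives in cosmos_predict1/autoregressive/configs/config.py, which is not shown here), and ConfigStore.list() can be used to check what ended up registered:

from hydra.core.config_store import ConfigStore

from cosmos_predict1.autoregressive.configs.registry import register_configs

register_configs()
cs = ConfigStore.instance()
print(sorted(cs.list("optimizer")))   # expected to include 'fused_adamw.yaml' and 'sgd.yaml'
print(sorted(cs.list("experiment")))  # expected to include the tealrobot experiments
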
cosmos_predict1/autoregressive/datasets/dataset_utils.py ADDED
@@ -0,0 +1,173 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional
17
+
18
+ import torch
19
+ import torchvision.transforms.functional as transforms_F
20
+ from PIL import Image
21
+
22
+
23
+ def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]:
24
+ r"""Function for obtaining the image size from the data dict.
25
+
26
+ Args:
27
+ data_dict (dict): Input data dict
28
+ input_keys (list): List of input keys
29
+ Returns:
30
+ width (int): Width of the input image
31
+ height (int): Height of the input image
32
+ """
33
+
34
+ data1 = data_dict[input_keys[0]]
35
+ if isinstance(data1, Image.Image):
36
+ width, height = data1.size
37
+ elif isinstance(data1, torch.Tensor):
38
+ height, width = data1.size()[-2:]
39
+ else:
40
+ raise ValueError("data to random crop should be PIL Image or tensor")
41
+
42
+ return width, height
43
+
44
+
45
+ class Augmentor:
46
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
47
+ r"""Base augmentor class
48
+
49
+ Args:
50
+ input_keys (list): List of input keys
51
+ output_keys (list): List of output keys
52
+ args (dict): Arguments associated with the augmentation
53
+ """
54
+ self.input_keys = input_keys
55
+ self.output_keys = output_keys
56
+ self.args = args
57
+
58
+ def __call__(self, *args: Any, **kwds: Any) -> Any:
59
+ raise NotImplementedError("Augmentor subclasses must implement __call__")
60
+
61
+
62
+ class ResizeSmallestSideAspectPreserving(Augmentor):
63
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
64
+ super().__init__(input_keys, output_keys, args)
65
+
66
+ def __call__(self, data_dict: dict) -> dict:
67
+ r"""Performs aspect-ratio preserving resizing.
68
+ The whole image is scaled by the larger of the two ratios (img_w / orig_w) and (img_h / orig_h),
69
+ so the side that is proportionally closest to its target lands exactly on the target size and
70
+ the other side ends up at least as large, leaving room for a subsequent center crop.
71
+
72
+ Args:
73
+ data_dict (dict): Input data dict
74
+ Returns:
75
+ data_dict (dict): Output dict where images are resized
76
+ """
77
+
78
+ if self.output_keys is None:
79
+ self.output_keys = self.input_keys
80
+ assert self.args is not None, "Please specify args in augmentations"
81
+
82
+ img_w, img_h = self.args["img_w"], self.args["img_h"]
83
+
84
+ orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
85
+ scaling_ratio = max((img_w / orig_w), (img_h / orig_h))
86
+ target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5))
87
+
88
+ assert (
89
+ target_size[0] >= img_h and target_size[1] >= img_w
90
+ ), f"Resize error. orig {(orig_w, orig_h)} desire {(img_w, img_h)} compute {target_size}"
91
+
92
+ for inp_key, out_key in zip(self.input_keys, self.output_keys):
93
+ data_dict[out_key] = transforms_F.resize(
94
+ data_dict[inp_key],
95
+ size=target_size, # type: ignore
96
+ interpolation=self.args.get("interpolation", transforms_F.InterpolationMode.BICUBIC),
97
+ antialias=True,
98
+ )
99
+
100
+ if out_key != inp_key:
101
+ del data_dict[inp_key]
102
+ return data_dict
103
+
104
+
105
+ class CenterCrop(Augmentor):
106
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
107
+ super().__init__(input_keys, output_keys, args)
108
+
109
+ def __call__(self, data_dict: dict) -> dict:
110
+ r"""Performs center crop.
111
+
112
+ Args:
113
+ data_dict (dict): Input data dict
114
+ Returns:
115
+ data_dict (dict): Output dict where images are center cropped.
116
+ We also save the cropping parameters in the aug_params dict
117
+ so that it will be used by other transforms.
118
+ """
119
+ assert (
120
+ (self.args is not None) and ("img_w" in self.args) and ("img_h" in self.args)
121
+ ), "Please specify size in args"
122
+
123
+ img_w, img_h = self.args["img_w"], self.args["img_h"]
124
+
125
+ orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
126
+ for key in self.input_keys:
127
+ data_dict[key] = transforms_F.center_crop(data_dict[key], [img_h, img_w])
128
+
129
+ # We also add the aug params we use. This will be useful for other transforms
130
+ crop_x0 = (orig_w - img_w) // 2
131
+ crop_y0 = (orig_h - img_h) // 2
132
+ cropping_params = {
133
+ "resize_w": orig_w,
134
+ "resize_h": orig_h,
135
+ "crop_x0": crop_x0,
136
+ "crop_y0": crop_y0,
137
+ "crop_w": img_w,
138
+ "crop_h": img_h,
139
+ }
140
+
141
+ if "aug_params" not in data_dict:
142
+ data_dict["aug_params"] = dict()
143
+
144
+ data_dict["aug_params"]["cropping"] = cropping_params
145
+ data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"]))
146
+ return data_dict
147
+
148
+
149
+ class Normalize(Augmentor):
150
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
151
+ super().__init__(input_keys, output_keys, args)
152
+
153
+ def __call__(self, data_dict: dict) -> dict:
154
+ r"""Performs data normalization.
155
+
156
+ Args:
157
+ data_dict (dict): Input data dict
158
+ Returns:
159
+ data_dict (dict): Output dict where images are normalized.
160
+ """
161
+ assert self.args is not None, "Please specify args"
162
+
163
+ mean = self.args["mean"]
164
+ std = self.args["std"]
165
+
166
+ for key in self.input_keys:
167
+ if isinstance(data_dict[key], torch.Tensor):
168
+ data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255)
169
+ else:
170
+ data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor()
171
+
172
+ data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std)
173
+ return data_dict
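
A small self-contained sketch of how these augmentors compose on a dummy uint8 video tensor of shape [T, C, H, W], mirroring the pipeline in video_dataset.py below; the sizes are illustrative and match the "small" tealrobot resolution:

import torch

from cosmos_predict1.autoregressive.datasets.dataset_utils import (
    CenterCrop,
    Normalize,
    ResizeSmallestSideAspectPreserving,
)

data = {"video": torch.randint(0, 256, (8, 3, 480, 854), dtype=torch.uint8)}

resize = ResizeSmallestSideAspectPreserving(input_keys=["video"], args={"img_w": 640, "img_h": 384})
crop = CenterCrop(input_keys=["video"], args={"img_w": 640, "img_h": 384})
normalize = Normalize(input_keys=["video"], args={"mean": 0.5, "std": 0.5})

for aug in (resize, crop, normalize):
    data = aug(data)

print(data["video"].shape)  # torch.Size([8, 3, 384, 640]), values roughly in [-1, 1]
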
cosmos_predict1/autoregressive/datasets/video_dataset.py ADDED
@@ -0,0 +1,190 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Run this command to interactively debug:
18
+ PYTHONPATH=. python cosmos_predict1/autoregressive/datasets/video_dataset.py
19
+ """
20
+
21
+ import os
22
+ import traceback
23
+ import warnings
24
+ from concurrent.futures import ThreadPoolExecutor, as_completed
25
+
26
+ import numpy as np
27
+ import torch
28
+ from decord import VideoReader, cpu
29
+ from torch.utils.data import Dataset
30
+ from tqdm import tqdm
31
+
32
+ from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
33
+ from cosmos_predict1.autoregressive.datasets.dataset_utils import (
34
+ CenterCrop,
35
+ Normalize,
36
+ ResizeSmallestSideAspectPreserving,
37
+ )
38
+
39
+
40
+ class VideoDataset(Dataset):
41
+ def __init__(self, config: VideoDatasetConfig):
42
+ """Video Dataset class for loading video-to-video generation data."""
43
+
44
+ super().__init__()
45
+ self.dataset_dir = config.dataset_dir
46
+ self.sequence_interval = config.sequence_interval
47
+ self.sequence_length = config.num_frames
48
+ self.video_size = config.video_size
49
+ self.start_frame_interval = config.start_frame_interval
50
+
51
+ self.video_dir = self.dataset_dir
52
+ self.video_paths = [os.path.join(self.video_dir, f) for f in os.listdir(self.video_dir) if f.endswith(".mp4")]
53
+ print(f"{len(self.video_paths)} videos in total")
54
+
55
+ self.samples = self._init_samples(self.video_paths)
56
+ self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0]))
57
+ print(f"{len(self.samples)} samples in total")
58
+ self.wrong_number = 0
59
+
60
+ self.resize_transform = ResizeSmallestSideAspectPreserving(
61
+ input_keys=["video"],
62
+ args={"img_w": self.video_size[1], "img_h": self.video_size[0]},
63
+ )
64
+ self.crop_transform = CenterCrop(
65
+ input_keys=["video"],
66
+ args={"img_w": self.video_size[1], "img_h": self.video_size[0]},
67
+ )
68
+ self.normalize_transform = Normalize(
69
+ input_keys=["video"],
70
+ args={"mean": 0.5, "std": 0.5},
71
+ )
72
+
73
+ def __str__(self):
74
+ return f"{len(self.video_paths)} samples from {self.dataset_dir}"
75
+
76
+ def _init_samples(self, video_paths):
77
+ samples = []
78
+ with ThreadPoolExecutor(32) as executor:
79
+ future_to_video_path = {
80
+ executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths
81
+ }
82
+ for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)):
83
+ samples.extend(future.result())
84
+ return samples
85
+
86
+ def _load_and_process_video_path(self, video_path):
87
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
88
+ n_frames = len(vr)
89
+
90
+ samples = []
91
+ for frame_i in range(0, n_frames, self.start_frame_interval):
92
+ sample = dict()
93
+ sample["video_path"] = video_path
94
+ sample["orig_num_frames"] = n_frames
95
+ sample["chunk_index"] = -1
96
+ sample["frame_ids"] = []
97
+ curr_frame_i = frame_i
98
+ while True:
99
+ if curr_frame_i > (n_frames - 1):
100
+ break
101
+ sample["frame_ids"].append(curr_frame_i)
102
+ if len(sample["frame_ids"]) == self.sequence_length:
103
+ break
104
+ curr_frame_i += self.sequence_interval
105
+ # make sure there are sequence_length number of frames
106
+ if len(sample["frame_ids"]) == self.sequence_length:
107
+ sample["chunk_index"] += 1
108
+ samples.append(sample)
109
+ return samples
110
+
111
+ def __len__(self):
112
+ return len(self.samples)
113
+
114
+ def _load_video(self, video_path, frame_ids):
115
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
116
+ assert (np.array(frame_ids) < len(vr)).all(), "Some frame_ids are out of range."
117
+ assert (np.array(frame_ids) >= 0).all(), "Some frame_ids are negative."
118
+ vr.seek(0)
119
+ frame_data = vr.get_batch(frame_ids).asnumpy()
120
+ fps = vr.get_avg_fps()
121
+ return frame_data, fps
122
+
123
+ def _get_frames(self, video_path, frame_ids):
124
+ frames, fps = self._load_video(video_path, frame_ids)
125
+ frames = frames.astype(np.uint8)
126
+ frames = torch.from_numpy(frames)
127
+ frames = frames.permute(0, 3, 1, 2) # Rearrange from [T, H, W, C] to [T, C, H, W]
128
+ return frames, fps
129
+
130
+ def __getitem__(self, index):
131
+ try:
132
+ sample = self.samples[index]
133
+ video_path = sample["video_path"]
134
+ frame_ids = sample["frame_ids"]
135
+
136
+ data = dict()
137
+
138
+ video, fps = self._get_frames(video_path, frame_ids)
139
+ data["video"] = video
140
+ data["fps"] = fps
141
+ data["num_frames"] = self.sequence_length
142
+ data["orig_num_frames"] = sample["orig_num_frames"]
143
+ data["chunk_index"] = sample["chunk_index"]
144
+ data["frame_start"] = frame_ids[0]
145
+ data["frame_end"] = frame_ids[-1]
146
+
147
+ data["video_name"] = {
148
+ "video_path": video_path,
149
+ "start_frame_id": str(frame_ids[0]),
150
+ }
151
+
152
+ # resize video to smallest side aspect preserving
153
+ data = self.resize_transform(data)
154
+ # center crop video
155
+ data = self.crop_transform(data)
156
+ # normalize video
157
+ data = self.normalize_transform(data)
158
+
159
+ data["video"] = data["video"].permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W]
160
+
161
+ return data
162
+ except Exception:
163
+ warnings.warn(
164
+ f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped "
165
+ f"(by randomly sampling another sample in the same dataset)."
166
+ )
167
+ warnings.warn("FULL TRACEBACK:")
168
+ warnings.warn(traceback.format_exc())
169
+ self.wrong_number += 1
170
+ print(f"Total number of skipped samples so far: {self.wrong_number}")
171
+ return self[np.random.randint(len(self.samples))]
172
+
173
+
174
+ if __name__ == "__main__":
175
+ config = VideoDatasetConfig(dataset_dir="datasets/cosmos_nemo_assets/videos/")
176
+ dataset = VideoDataset(config)
177
+
178
+ indices = [0, 1, 2, -1]
179
+ for idx in indices:
180
+ data = dataset[idx]
181
+ print(
182
+ (
183
+ f"{idx=} "
184
+ f"{data['video'].sum()=}\n"
185
+ f"{data['video'].shape=}\n"
186
+ f"{data['video_name']=}\n"
187
+ f"{data.keys()=}\n"
188
+ "---"
189
+ )
190
+ )
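
The dataset returns a flat dict of tensors, numbers, and short strings, so it should work with the default collate function; a minimal sketch (the dataset path is illustrative and matches the __main__ block above):

from torch.utils.data import DataLoader

from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset

dataset = VideoDataset(VideoDatasetConfig(dataset_dir="datasets/cosmos_nemo_assets/videos/"))
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["video"].shape)  # [B, C, T, H, W] after the permute in __getitem__
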
cosmos_predict1/autoregressive/diffusion_decoder/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py ADDED
@@ -0,0 +1,61 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Dict, Optional
18
+
19
+ import torch
20
+
21
+ from cosmos_predict1.diffusion.conditioner import BaseVideoCondition, GeneralConditioner
22
+ from cosmos_predict1.diffusion.config.base.conditioner import (
23
+ FPSConfig,
24
+ ImageSizeConfig,
25
+ LatentConditionConfig,
26
+ LatentConditionSigmaConfig,
27
+ NumFramesConfig,
28
+ PaddingMaskConfig,
29
+ TextConfig,
30
+ )
31
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
32
+ from cosmos_predict1.utils.lazy_config import LazyDict
33
+
34
+
35
+ @dataclass
36
+ class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
37
+ # latent_condition will concat to the input of network, along channel dim;
38
+ # cfg will make latent_condition all zero padding.
39
+ latent_condition: Optional[torch.Tensor] = None
40
+ latent_condition_sigma: Optional[torch.Tensor] = None
41
+
42
+
43
+ class VideoDiffusionDecoderConditioner(GeneralConditioner):
44
+ def forward(
45
+ self,
46
+ batch: Dict,
47
+ override_dropout_rate: Optional[Dict[str, float]] = None,
48
+ ) -> VideoLatentDiffusionDecoderCondition:
49
+ output = super()._forward(batch, override_dropout_rate)
50
+ return VideoLatentDiffusionDecoderCondition(**output)
51
+
52
+
53
+ VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)(
54
+ text=TextConfig(),
55
+ fps=FPSConfig(),
56
+ num_frames=NumFramesConfig(),
57
+ image_size=ImageSizeConfig(),
58
+ padding_mask=PaddingMaskConfig(),
59
+ latent_condition=LatentConditionConfig(),
60
+ latent_condition_sigma=LatentConditionSigmaConfig(),
61
+ )
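
A standalone sketch in plain torch (not this library's API) of the comment above: for classifier-free guidance, the conditioning latents of the unconditional branch are replaced by zero padding before being concatenated to the network input along the channel dimension; the shapes are illustrative.

import torch

x = torch.randn(1, 16, 8, 88, 160)                 # noisy input latents [B, C, T, H, W]
latent_condition = torch.randn(1, 16, 8, 88, 160)  # conditioning latents

cond_input = torch.cat([x, latent_condition], dim=1)                      # conditional branch
uncond_input = torch.cat([x, torch.zeros_like(latent_condition)], dim=1)  # cfg branch
print(cond_input.shape, uncond_input.shape)        # channel dimension doubles from 16 to 32
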
cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py ADDED
@@ -0,0 +1,61 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, List
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs
21
+ from cosmos_predict1.diffusion.config.base.model import LatentDiffusionDecoderModelConfig
22
+ from cosmos_predict1.diffusion.config.registry import register_configs
23
+ from cosmos_predict1.utils import config
24
+ from cosmos_predict1.utils.config_helper import import_all_modules_from_package
25
+
26
+
27
+ @attrs.define(slots=False)
28
+ class Config(config.Config):
29
+ # default config groups that will be used unless overwritten
30
+ # see config groups in registry.py
31
+ defaults: List[Any] = attrs.field(
32
+ factory=lambda: [
33
+ "_self_",
34
+ {"net": None},
35
+ {"conditioner": "basic"},
36
+ {"tokenizer": "tokenizer"},
37
+ {"tokenizer_corruptor": None},
38
+ {"latent_corruptor": None},
39
+ {"pixel_corruptor": None},
40
+ {"experiment": None},
41
+ ]
42
+ )
43
+
44
+
45
+ def make_config():
46
+ c = Config(model=LatentDiffusionDecoderModelConfig())
47
+
48
+ # Specifying values through instances of attrs
49
+ c.job.project = "cosmos_video4"
50
+ c.job.group = "debug"
51
+ c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}"
52
+
53
+ # Call this function to register config groups for advanced overriding.
54
+ register_configs()
55
+ register_dd_configs()
56
+
57
+ # Experiment configs are defined in the inference config packages;
58
+ # call import_all_modules_from_package to register them
59
+ import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True)
60
+ import_all_modules_from_package("cosmos_predict1.autoregressive.diffusion_decoder.config.inference", reload=True)
61
+ return c
cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py ADDED
@@ -0,0 +1,85 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from hydra.core.config_store import ConfigStore
17
+
18
+ from cosmos_predict1.autoregressive.diffusion_decoder.network import DiffusionDecoderGeneralDIT
19
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
20
+ from cosmos_predict1.utils.lazy_config import LazyDict
21
+
22
+ num_frames = 57
23
+ Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict(
24
+ dict(
25
+ defaults=[
26
+ {"override /net": "faditv2_7b"},
27
+ {"override /tokenizer": "cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624"},
28
+ {"override /conditioner": "video_latent_diffusion_decoder_cond"},
29
+ {"override /tokenizer_corruptor": "cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224"},
30
+ "_self_",
31
+ ],
32
+ job=dict(
33
+ group="diffusion_deocder_FT_7Bv1_001",
34
+ name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token",
35
+ ),
36
+ model=dict(
37
+ diffusion_decoder_cond_sigma_low=0.0,
38
+ diffusion_decoder_cond_sigma_high=0.0,
39
+ diffusion_decoder_corrupt_prob=0.0,
40
+ condition_on_tokenizer_corruptor_token=True,
41
+ latent_shape=[
42
+ 16,
43
+ num_frames,
44
+ 88,
45
+ 160,
46
+ ],
47
+ tokenizer_corruptor=dict(
48
+ pixel_chunk_duration=num_frames,
49
+ latent_chunk_duration=1 + (num_frames - 1) // 8,
50
+ ),
51
+ net=L(DiffusionDecoderGeneralDIT)(
52
+ diffusion_decoder_condition_on_sigma=False,
53
+ max_img_h=240,
54
+ max_img_w=240,
55
+ rope_h_extrapolation_ratio=1.5,
56
+ rope_w_extrapolation_ratio=1.5,
57
+ rope_t_extrapolation_ratio=1,
58
+ block_x_format="THWBD",
59
+ is_diffusion_decoder=True,
60
+ patch_spatial=2,
61
+ diffusion_decoder_condition_on_token=True,
62
+ diffusion_decoder_token_condition_voc_size=64000,
63
+ diffusion_decoder_token_condition_dim=32,
64
+ ),
65
+ tokenizer=dict(
66
+ video_vae=dict(
67
+ pixel_chunk_duration=num_frames,
68
+ )
69
+ ),
70
+ conditioner=dict(
71
+ latent_condition=dict(
72
+ dropout_rate=0.2,
73
+ )
74
+ ),
75
+ ),
76
+ )
77
+ )
78
+
79
+ cs = ConfigStore.instance()
80
+ cs.store(
81
+ group="experiment",
82
+ package="_global_",
83
+ name=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY["job"]["name"],
84
+ node=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY,
85
+ )
cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py ADDED
@@ -0,0 +1,118 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from hydra.core.config_store import ConfigStore
17
+
18
+ from cosmos_predict1.autoregressive.diffusion_decoder.config.base.conditioner import (
19
+ VideoLatentDiffusionDecoderConditionerConfig,
20
+ )
21
+ from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQJITTokenizer
22
+ from cosmos_predict1.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
23
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
24
+
25
+
26
+ def get_cosmos_video_discrete_tokenizer_comp8x16x16(
27
+ resolution: str,
28
+ chunk_duration: int,
29
+ checkpoint_path: str,
30
+ ):
31
+ assert resolution in ["720"]
32
+
33
+ pixel_chunk_duration = chunk_duration
34
+ temporal_compression_factor = 8
35
+ spatial_compression_factor = 16
36
+
37
+ return L(DiscreteVideoFSQJITTokenizer)(
38
+ enc_fp=checkpoint_path.replace(".jit", "encoder.jit"),
39
+ dec_fp=checkpoint_path.replace(".jit", "decoder.jit"),
40
+ name="discrete_video_fsq",
41
+ latent_ch=6,
42
+ is_bf16=True,
43
+ pixel_chunk_duration=pixel_chunk_duration,
44
+ latent_chunk_duration=1 + (pixel_chunk_duration - 1) // temporal_compression_factor,
45
+ max_enc_batch_size=8,
46
+ max_dec_batch_size=4,
47
+ levels=[8, 8, 8, 5, 5, 5],
48
+ compression_ratio=[temporal_compression_factor, spatial_compression_factor, spatial_compression_factor],
49
+ )
50
+
51
+
52
+ def get_cosmos_video_tokenizer_comp8x8x8(resolution: str, chunk_duration: int, checkpoint_path=None):
53
+ pixel_chunk_duration = chunk_duration
54
+ temporal_compression_factor = 8
55
+ spatial_compression_factor = 8
56
+
57
+ return L(JointImageVideoSharedJITTokenizer)(
58
+ video_vae=L(VideoJITTokenizer)(
59
+ name="cosmos_predict1_tokenizer",
60
+ latent_ch=16,
61
+ is_bf16=True,
62
+ pixel_chunk_duration=pixel_chunk_duration,
63
+ temporal_compression_factor=temporal_compression_factor,
64
+ spatial_compression_factor=spatial_compression_factor,
65
+ spatial_resolution=resolution,
66
+ ),
67
+ image_vae=L(JITVAE)(
68
+ name="cosmos_predict1_tokenizer",
69
+ latent_ch=16,
70
+ is_image=False,
71
+ is_bf16=True,
72
+ ),
73
+ name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624",
74
+ latent_ch=16,
75
+ )
76
+
77
+
78
+ def register_tokenizer(cs):
79
+ cs.store(
80
+ group="tokenizer",
81
+ package="model.tokenizer",
82
+ name="cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624",
83
+ node=get_cosmos_video_tokenizer_comp8x8x8(
84
+ resolution="720",
85
+ chunk_duration=121,
86
+ checkpoint_path="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/.jit",
87
+ ),
88
+ )
89
+
90
+
91
+ def register_corruptor(cs):
92
+ cs.store(
93
+ group="tokenizer_corruptor",
94
+ package="model.tokenizer_corruptor",
95
+ name="cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224",
96
+ node=get_cosmos_video_discrete_tokenizer_comp8x16x16(
97
+ resolution="720",
98
+ chunk_duration=49,
99
+ checkpoint_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/.jit",
100
+ ),
101
+ )
102
+
103
+
104
+ def register_conditioner(cs):
105
+ cs.store(
106
+ group="conditioner",
107
+ package="model.conditioner",
108
+ name="video_latent_diffusion_decoder_cond",
109
+ node=VideoLatentDiffusionDecoderConditionerConfig,
110
+ )
111
+
112
+
113
+ def register_configs():
114
+ cs = ConfigStore.instance()
115
+
116
+ register_conditioner(cs)
117
+ register_corruptor(cs)
118
+ register_tokenizer(cs)
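
A quick check of the chunk-duration arithmetic used in these tokenizer configs: with a temporal compression factor of 8, latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // 8, so the 49-frame corruptor chunk maps to 7 latent frames, the 57-frame decoder setting to 8, and the 121-frame continuous tokenizer chunk to 16.

for pixel_chunk_duration in (49, 57, 121):
    latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // 8
    print(pixel_chunk_duration, "->", latent_chunk_duration)  # 49 -> 7, 57 -> 8, 121 -> 16
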