Spaces:
Build error
Build error
Upload 381 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +22 -0
- assets/demo_1.gif +3 -0
- assets/demo_2.gif +3 -0
- assets/demo_3.gif +3 -0
- assets/demo_dynamic.gif +3 -0
- assets/diffusion/000000.png +3 -0
- assets/diffusion/000001.png +3 -0
- assets/diffusion/000002.png +3 -0
- assets/diffusion/000003.png +3 -0
- assets/diffusion/000004.png +3 -0
- assets/diffusion/000005.png +3 -0
- assets/diffusion/000006.png +3 -0
- assets/diffusion/000007.png +3 -0
- assets/diffusion/000008.png +3 -0
- assets/diffusion/000009.png +3 -0
- assets/diffusion/000010.png +3 -0
- assets/diffusion/000011.png +3 -0
- assets/diffusion/000012.png +3 -0
- assets/diffusion/000013.png +3 -0
- assets/diffusion/000014.png +3 -0
- assets/diffusion/000015.png +3 -0
- checkpoints/.DS_Store +0 -0
- checkpoints/README.md +4 -0
- cosmos-predict1.yaml +29 -0
- cosmos_predict1/.DS_Store +0 -0
- cosmos_predict1/__init__.py +14 -0
- cosmos_predict1/autoregressive/__init__.py +14 -0
- cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py +352 -0
- cosmos_predict1/autoregressive/configs/__init__.py +14 -0
- cosmos_predict1/autoregressive/configs/base/__init__.py +14 -0
- cosmos_predict1/autoregressive/configs/base/callbacks.py +33 -0
- cosmos_predict1/autoregressive/configs/base/dataloader.py +72 -0
- cosmos_predict1/autoregressive/configs/base/dataset.py +39 -0
- cosmos_predict1/autoregressive/configs/base/model.py +318 -0
- cosmos_predict1/autoregressive/configs/base/model_config.py +718 -0
- cosmos_predict1/autoregressive/configs/base/model_parallel.py +33 -0
- cosmos_predict1/autoregressive/configs/base/optim.py +86 -0
- cosmos_predict1/autoregressive/configs/base/tokenizer.py +139 -0
- cosmos_predict1/autoregressive/configs/config.py +111 -0
- cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py +0 -0
- cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py +163 -0
- cosmos_predict1/autoregressive/configs/inference/inference_config.py +102 -0
- cosmos_predict1/autoregressive/configs/registry.py +89 -0
- cosmos_predict1/autoregressive/datasets/dataset_utils.py +173 -0
- cosmos_predict1/autoregressive/datasets/video_dataset.py +190 -0
- cosmos_predict1/autoregressive/diffusion_decoder/__init__.py +14 -0
- cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py +61 -0
- cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py +61 -0
- cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py +85 -0
- cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py +118 -0
.gitattributes
CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/demo_1.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/demo_2.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/demo_3.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/demo_dynamic.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/diffusion/000000.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/diffusion/000001.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/diffusion/000002.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/diffusion/000003.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/diffusion/000004.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/diffusion/000005.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/diffusion/000006.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
assets/diffusion/000007.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
assets/diffusion/000008.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
assets/diffusion/000009.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
assets/diffusion/000010.png filter=lfs diff=lfs merge=lfs -text
|
51 |
+
assets/diffusion/000011.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
assets/diffusion/000012.png filter=lfs diff=lfs merge=lfs -text
|
53 |
+
assets/diffusion/000013.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
assets/diffusion/000014.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
assets/diffusion/000015.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
cosmos_predict1/tokenizer/test_data/image.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
cosmos_predict1/tokenizer/test_data/video.mp4 filter=lfs diff=lfs merge=lfs -text
|
assets/demo_1.gif
ADDED
![]() |
Git LFS Details
|
assets/demo_2.gif
ADDED
![]() |
Git LFS Details
|
assets/demo_3.gif
ADDED
![]() |
Git LFS Details
|
assets/demo_dynamic.gif
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000000.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000001.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000002.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000003.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000004.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000005.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000006.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000007.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000008.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000009.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000010.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000011.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000012.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000013.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000014.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000015.png
ADDED
![]() |
Git LFS Details
|
checkpoints/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
checkpoints/README.md
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
### Checkpoint directory
|
3 |
+
|
4 |
+
Model checkpoints will be downloaded to this directory.
|
cosmos-predict1.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# conda env create --file cosmos-predict1.yaml
|
17 |
+
name: cosmos-predict1
|
18 |
+
channels:
|
19 |
+
- conda-forge
|
20 |
+
dependencies:
|
21 |
+
- python=3.10
|
22 |
+
- pip=25.0
|
23 |
+
- cmake
|
24 |
+
- ninja
|
25 |
+
- gcc=12.4.0
|
26 |
+
- gxx=12.4.0
|
27 |
+
- cuda=12.4
|
28 |
+
- cuda-nvcc=12.4
|
29 |
+
- cuda-toolkit=12.4
|
cosmos_predict1/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
cosmos_predict1/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import glob
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torchvision
|
24 |
+
import torchvision.transforms.functional as torchvision_F
|
25 |
+
import wandb
|
26 |
+
from einops import rearrange
|
27 |
+
from megatron.core import parallel_state
|
28 |
+
from torch.distributed import get_process_group_ranks
|
29 |
+
|
30 |
+
from cosmos_predict1.autoregressive.utils.parallel import (
|
31 |
+
broadcast_data_batch_in_tp_cp_group,
|
32 |
+
gather_batch_from_cp_ranks,
|
33 |
+
get_batch_on_this_cp_rank,
|
34 |
+
)
|
35 |
+
from cosmos_predict1.callbacks.every_n import EveryN
|
36 |
+
from cosmos_predict1.utils import distributed, log, misc
|
37 |
+
from cosmos_predict1.utils.model import Model
|
38 |
+
from cosmos_predict1.utils.trainer import Trainer
|
39 |
+
|
40 |
+
|
41 |
+
def resize_image(image: torch.Tensor, resize_factor=0.5) -> torch.Tensor:
|
42 |
+
_, _, h, w = image.shape
|
43 |
+
new_h, new_w = int(resize_factor * h), int(resize_factor * w)
|
44 |
+
return torchvision_F.resize(image, (new_h, new_w))
|
45 |
+
|
46 |
+
|
47 |
+
class VideoSamplingTeacherForcing(EveryN):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
every_n: int,
|
51 |
+
step_size: int = 1,
|
52 |
+
video_latent_shape: list = [6, 24, 40],
|
53 |
+
num_frames_to_display: int = 4,
|
54 |
+
save_folder: Optional[str] = None,
|
55 |
+
num_file_to_log: int = 8,
|
56 |
+
):
|
57 |
+
r"""
|
58 |
+
This callback enables us to perform teacher forcing inference on the training data.
|
59 |
+
By teacher forcing, we mean providing ground truth video tokens as inputs, and simply asking the model
|
60 |
+
to predict the next tokens. The predicted next tokens are then visualized. This does not perform
|
61 |
+
autoregressive sampling.
|
62 |
+
We also upload the downsampled video frames to wandb. Downsampling is needed for wandb to work fast.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
every_n (int): Call this callback every_n steps
|
66 |
+
step_size (int): Number of steps taken for gradient accumulation. Global iteration number is
|
67 |
+
iteration // self.step_size
|
68 |
+
video_latent_shape (list): Shape of the video latent
|
69 |
+
num_frames_to_display (int): Number of frames to subsample for displaying in wandb
|
70 |
+
save_folder (str): Name of the local folder to save the video
|
71 |
+
num_file_to_log (int): Number of files to upload to wandb
|
72 |
+
"""
|
73 |
+
super().__init__(every_n, step_size)
|
74 |
+
self.save_folder = save_folder if save_folder else self.__class__.__name__
|
75 |
+
self.video_latent_shape = video_latent_shape
|
76 |
+
self.num_frames_to_display = num_frames_to_display
|
77 |
+
self.num_file_to_log = num_file_to_log
|
78 |
+
self.rank = distributed.get_rank()
|
79 |
+
|
80 |
+
def on_train_start(self, model: Model, iteration: int = 0) -> None:
|
81 |
+
config_job = self.config.job
|
82 |
+
self.local_dir = f"{config_job.path_local}/{self.save_folder}"
|
83 |
+
if self.rank == 0:
|
84 |
+
os.makedirs(self.local_dir, exist_ok=True)
|
85 |
+
log.info(f"Video Teacher-Forcing Callback: local_dir: {self.local_dir}")
|
86 |
+
|
87 |
+
@torch.inference_mode()
|
88 |
+
def every_n_impl(
|
89 |
+
self,
|
90 |
+
trainer: Trainer,
|
91 |
+
model: Model,
|
92 |
+
data_batch: dict[str, torch.Tensor],
|
93 |
+
output_batch: dict[str, torch.Tensor],
|
94 |
+
loss: torch.Tensor,
|
95 |
+
iteration: int,
|
96 |
+
) -> None:
|
97 |
+
# Tokenize the data
|
98 |
+
|
99 |
+
broadcast_data_batch_in_tp_cp_group(data_batch)
|
100 |
+
|
101 |
+
input_vid = data_batch[model.tokenizer.tokenizer_config.video_tokenizer.data_key]
|
102 |
+
|
103 |
+
dataset_name = data_batch.get("dataset_name", None)
|
104 |
+
if dataset_name is not None and dataset_name.startswith("image"):
|
105 |
+
# we disable the callback if the input video is an image batch
|
106 |
+
log.info(f"dataset_name is {dataset_name}, skip this callback")
|
107 |
+
return
|
108 |
+
|
109 |
+
# get the caption
|
110 |
+
captions = data_batch.get("caption", None)
|
111 |
+
|
112 |
+
# get the context embedding and mask
|
113 |
+
context = data_batch.get("context", None)
|
114 |
+
context_mask = data_batch.get("context_mask", None)
|
115 |
+
if context is not None:
|
116 |
+
context = misc.to(context, "cuda").detach().clone()
|
117 |
+
if context_mask is not None:
|
118 |
+
context_mask = misc.to(context_mask, "cuda").detach().clone()
|
119 |
+
# get the action
|
120 |
+
action = data_batch.get("action", None)
|
121 |
+
if action is not None:
|
122 |
+
action = misc.to(action, "cuda").detach().clone()
|
123 |
+
|
124 |
+
# Input tokens
|
125 |
+
tokens, _ = model.tokenizer.tokenize(data_batch)
|
126 |
+
tokens = misc.to(tokens, "cuda").detach().clone()
|
127 |
+
skip_save_file = False
|
128 |
+
if parallel_state.get_context_parallel_world_size() > 1:
|
129 |
+
cp_group = parallel_state.get_context_parallel_group()
|
130 |
+
if self.rank != min(get_process_group_ranks(cp_group)):
|
131 |
+
skip_save_file = True
|
132 |
+
tokens = get_batch_on_this_cp_rank(tokens)
|
133 |
+
if parallel_state.get_tensor_model_parallel_world_size() > 1:
|
134 |
+
# Turn on TP
|
135 |
+
tp_group = parallel_state.get_tensor_model_parallel_group()
|
136 |
+
if self.rank != min(get_process_group_ranks(tp_group)):
|
137 |
+
skip_save_file = True
|
138 |
+
tokens_encoded_in_train = output_batch["encode_tokens"].detach()
|
139 |
+
percent_token_diff = (tokens != tokens_encoded_in_train).float().mean()
|
140 |
+
percent_token_diff = distributed.dist_reduce_tensor(percent_token_diff)
|
141 |
+
|
142 |
+
input_tokens = tokens
|
143 |
+
|
144 |
+
num_tokens_to_generate = np.prod(self.video_latent_shape)
|
145 |
+
|
146 |
+
# Do a forward pass
|
147 |
+
logits = model.model.forward(
|
148 |
+
tokens,
|
149 |
+
input_pos=None,
|
150 |
+
context=context,
|
151 |
+
context_mask=context_mask,
|
152 |
+
action=action,
|
153 |
+
)
|
154 |
+
if parallel_state.get_context_parallel_world_size() > 1:
|
155 |
+
logits = gather_batch_from_cp_ranks(logits)
|
156 |
+
input_tokens = gather_batch_from_cp_ranks(input_tokens)
|
157 |
+
|
158 |
+
# Start position for video tokens in the vocabulary
|
159 |
+
video_token_start = self.config.model.tokenizer_config.video_tokenizer.tokenizer_offset
|
160 |
+
video_vocab_size = self.config.model.tokenizer_config.video_tokenizer.vocab_size
|
161 |
+
|
162 |
+
# Clipping logits only to video tokens. We remove the text vocab predictions.
|
163 |
+
# This will ensure that the video tokens only correspond to the video part of the vocabulary.
|
164 |
+
logits = logits[:, :, video_token_start : video_token_start + video_vocab_size]
|
165 |
+
|
166 |
+
# Sample with argmax token. This should be good for teacher forcing experiment.
|
167 |
+
logits = logits.contiguous()
|
168 |
+
generations = torch.argmax(logits, dim=-1)
|
169 |
+
|
170 |
+
# For each video in the batch, subsample frames for display
|
171 |
+
batch_size = input_tokens.shape[0]
|
172 |
+
out_frames = []
|
173 |
+
out_videos_gen = []
|
174 |
+
out_videos_rec = []
|
175 |
+
out_videos_gt = []
|
176 |
+
# log the accuracy of teacher-forcing
|
177 |
+
acc = []
|
178 |
+
loss_list = []
|
179 |
+
|
180 |
+
for sample_num in range(batch_size):
|
181 |
+
# Subsample the generations to the video part.
|
182 |
+
# This corresponds to the part from begin of video to end of video.
|
183 |
+
bov_token = model.tokenizer.video_special_tokens["<|begin_of_video|>"]
|
184 |
+
bov_index = input_tokens[sample_num] == bov_token
|
185 |
+
use_special_token = sum(bov_index) != 0
|
186 |
+
if use_special_token:
|
187 |
+
bov_index = bov_index.nonzero().item()
|
188 |
+
# generations: <bov> real_token1 real_token2, ... real_token7680; total 7680
|
189 |
+
# gen_video_tokens: real_token1 real_token2, ..., real_token7680; total 7680
|
190 |
+
# for vis: real_token1 real_token2, ..., real_token7680; total 7680
|
191 |
+
# for accuracy: real_token1 real_token2, ..., real_token7680; total 7680
|
192 |
+
gen_video_tokens = generations[sample_num][bov_index : bov_index + num_tokens_to_generate]
|
193 |
+
gen_video_tokens_vis = gen_video_tokens
|
194 |
+
gen_video_tokens_acc = gen_video_tokens
|
195 |
+
logits_loss = logits[sample_num][bov_index : bov_index + num_tokens_to_generate]
|
196 |
+
else:
|
197 |
+
# generations: real_token1 real_token2, ... real_token7680
|
198 |
+
# gen_video_tokens: real_token2 real_token3, ..., real_token7680; total 7679
|
199 |
+
# We need different tokens for vis and accuracy compute
|
200 |
+
# for acc: real_token2 real_token3, ..., real_token7680; total 7679
|
201 |
+
# for vis: pad_token (real_token2, ..., real_token7680); total 1 + 7679
|
202 |
+
gen_video_tokens = generations[sample_num][
|
203 |
+
: num_tokens_to_generate - 1
|
204 |
+
] # remove the last token since there is no gt
|
205 |
+
# Since the first token is not predicted, we need to add the gt first token to make sure the shape is correct
|
206 |
+
gen_video_tokens_vis = torch.cat([input_tokens[sample_num][0:1], gen_video_tokens])
|
207 |
+
gen_video_tokens_acc = gen_video_tokens
|
208 |
+
logits_loss = logits[sample_num][: num_tokens_to_generate - 1]
|
209 |
+
|
210 |
+
# Rearrange the video to a spatial tensor
|
211 |
+
gen_video_tokens_vis_BTHW = rearrange(
|
212 |
+
gen_video_tokens_vis.unsqueeze(0),
|
213 |
+
"B (T H W) -> B T H W",
|
214 |
+
T=self.video_latent_shape[0],
|
215 |
+
H=self.video_latent_shape[1],
|
216 |
+
W=self.video_latent_shape[2],
|
217 |
+
)
|
218 |
+
|
219 |
+
# for real videos, we need to skip the bov and eov tokens for decoding
|
220 |
+
if use_special_token:
|
221 |
+
# input_tokens: <bov> real_token1 real_token2 ... <eov> <eov> ...
|
222 |
+
# real_video_tokens: real_token1 real_token2 ... real_token7680; total 7680
|
223 |
+
# for vis: real_token1 real_token2 ... real_token7680; total 7680
|
224 |
+
# for accuracy: real_token1 real_token2 ... real_token7680; total 7680; we include real_token1 since the output prediction also includes it, see gen_video_tokens_acc above
|
225 |
+
real_video_tokens = (
|
226 |
+
input_tokens[sample_num][bov_index + 1 : bov_index + num_tokens_to_generate + 1] - video_token_start
|
227 |
+
)
|
228 |
+
real_video_tokens_vis = real_video_tokens
|
229 |
+
real_video_tokens_acc = real_video_tokens
|
230 |
+
else:
|
231 |
+
# input_tokens: real_token1 real_token2,... real_token7680; total 7680
|
232 |
+
# real_video_tokens: real_token1 real_token2,... real_token7680; total 7680
|
233 |
+
# for acc: gt start from real_token2, real_token3; total 7679, remove the first token since it is not predicted
|
234 |
+
# for vis: gt start from real_token1, real_token2; total 7680
|
235 |
+
real_video_tokens = input_tokens[sample_num][:num_tokens_to_generate] - video_token_start
|
236 |
+
real_video_tokens_vis = real_video_tokens
|
237 |
+
real_video_tokens_acc = real_video_tokens[1:].flatten()
|
238 |
+
|
239 |
+
real_video_tokens_vis_BTHW = rearrange(
|
240 |
+
real_video_tokens_vis.unsqueeze(0),
|
241 |
+
"B (T H W) -> B T H W",
|
242 |
+
T=self.video_latent_shape[0],
|
243 |
+
H=self.video_latent_shape[1],
|
244 |
+
W=self.video_latent_shape[2],
|
245 |
+
)
|
246 |
+
# Calculate accuracy
|
247 |
+
correct_predictions = (gen_video_tokens_acc == real_video_tokens_acc).float()
|
248 |
+
labels = real_video_tokens_acc.clone()
|
249 |
+
|
250 |
+
if model.config.ignore_first_num_tokens > 0:
|
251 |
+
labels[: model.config.ignore_first_num_tokens] = model.tokenizer.ignore_index
|
252 |
+
select_index = labels != model.tokenizer.ignore_index
|
253 |
+
correct_predictions = correct_predictions[select_index]
|
254 |
+
|
255 |
+
loss = torch.nn.functional.cross_entropy(
|
256 |
+
logits_loss, labels, ignore_index=model.tokenizer.ignore_index, reduction="none"
|
257 |
+
)
|
258 |
+
acc.append(correct_predictions.mean() * 100.0)
|
259 |
+
loss_list.append(loss.mean())
|
260 |
+
|
261 |
+
# Decode the predicted latents
|
262 |
+
if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
|
263 |
+
vid_decoded = model.tokenizer.video_tokenizer.decode(gen_video_tokens_vis_BTHW.cuda())
|
264 |
+
else:
|
265 |
+
vid_decoded = model.tokenizer.video_tokenizer.decode_with_overlap(
|
266 |
+
gen_video_tokens_vis_BTHW.cuda(),
|
267 |
+
temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
|
268 |
+
)
|
269 |
+
# normalize decoded images from [-1, 1] to [0, 1], and clip value
|
270 |
+
vid_decoded = (vid_decoded * 0.5 + 0.5).clamp_(0, 1)
|
271 |
+
vid_decoded = vid_decoded[0]
|
272 |
+
|
273 |
+
# Decode the GT latents
|
274 |
+
if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
|
275 |
+
vid_rec = model.tokenizer.video_tokenizer.decode(real_video_tokens_vis_BTHW.cuda())
|
276 |
+
else:
|
277 |
+
vid_rec = model.tokenizer.video_tokenizer.decode_with_overlap(
|
278 |
+
real_video_tokens_vis_BTHW.cuda(),
|
279 |
+
temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
|
280 |
+
)
|
281 |
+
# normalize decoded image from [-1, 1] to [0, 1], and clip value
|
282 |
+
vid_rec = (vid_rec * 0.5 + 0.5).clamp_(0, 1)
|
283 |
+
vid_rec = vid_rec[0]
|
284 |
+
|
285 |
+
vid_input = input_vid[sample_num] # [-1, 1], input_vid shape: [B, C, L, H, W]
|
286 |
+
vid_input = (vid_input * 0.5 + 0.5).clamp_(0, 1).cuda() # Convert to [0, 1], [C, L, H, W]
|
287 |
+
|
288 |
+
# Subsample real and generated video frames
|
289 |
+
input_video_frames = vid_input.transpose(0, 1) # [L, C, H, W]
|
290 |
+
rec_video_frames = vid_rec.transpose(0, 1)
|
291 |
+
gen_video_frames = vid_decoded.transpose(0, 1)
|
292 |
+
out_videos_gen.append(gen_video_frames)
|
293 |
+
out_videos_rec.append(rec_video_frames)
|
294 |
+
out_videos_gt.append(input_video_frames)
|
295 |
+
|
296 |
+
stride = math.ceil(rec_video_frames.shape[0] / self.num_frames_to_display)
|
297 |
+
|
298 |
+
input_video_frames_subsampled = resize_image(input_video_frames[0::stride], resize_factor=0.5)
|
299 |
+
input_video_frames_subsampled = torchvision.utils.make_grid(
|
300 |
+
input_video_frames_subsampled, nrow=input_video_frames_subsampled.shape[0]
|
301 |
+
)
|
302 |
+
|
303 |
+
gt_video_frames_subsampled = resize_image(rec_video_frames[0::stride], resize_factor=0.5)
|
304 |
+
gt_video_frames_subsampled = torchvision.utils.make_grid(
|
305 |
+
gt_video_frames_subsampled, nrow=gt_video_frames_subsampled.shape[0]
|
306 |
+
)
|
307 |
+
gen_video_frames_subsampled = resize_image(gen_video_frames[0::stride], resize_factor=0.5)
|
308 |
+
gen_video_frames_subsampled = torchvision.utils.make_grid(
|
309 |
+
gen_video_frames_subsampled, nrow=gen_video_frames_subsampled.shape[0]
|
310 |
+
)
|
311 |
+
|
312 |
+
out_frames.append(input_video_frames_subsampled)
|
313 |
+
out_frames.append(gt_video_frames_subsampled)
|
314 |
+
out_frames.append(gen_video_frames_subsampled)
|
315 |
+
|
316 |
+
scaled_num_rank_to_log = (
|
317 |
+
self.num_file_to_log
|
318 |
+
* parallel_state.get_context_parallel_world_size()
|
319 |
+
* parallel_state.get_tensor_model_parallel_world_size()
|
320 |
+
)
|
321 |
+
if self.rank < scaled_num_rank_to_log and not skip_save_file:
|
322 |
+
local_path = f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_{self.rank:04d}.jpg"
|
323 |
+
out_image_grid = torchvision.utils.make_grid(out_frames, nrow=1, padding=0, normalize=False)
|
324 |
+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
325 |
+
torchvision.utils.save_image(out_image_grid, local_path)
|
326 |
+
|
327 |
+
# Log to wandb
|
328 |
+
avg_acc = distributed.dist_reduce_tensor(torch.stack(acc).mean()).item()
|
329 |
+
avg_loss = distributed.dist_reduce_tensor(torch.stack(loss_list).mean()).item()
|
330 |
+
log_info = ""
|
331 |
+
if "acc" in output_batch:
|
332 |
+
log_info = f"train acc: {(output_batch['acc'].mean().item()):.6f}%"
|
333 |
+
if percent_token_diff is not None:
|
334 |
+
log_info += f"; percent_token_diff_train_val: {percent_token_diff.item() * 100:.6f}%"
|
335 |
+
log.info(
|
336 |
+
f"Eval iteration {iteration} teacher-forcing accuracy: {avg_acc:.6f}%, loss: {avg_loss:.4f}; {log_info}"
|
337 |
+
)
|
338 |
+
if self.rank == 0 and wandb.run:
|
339 |
+
local_files = glob.glob(f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_*.jpg")
|
340 |
+
local_files = sorted(local_files)[: self.num_file_to_log]
|
341 |
+
if captions is None:
|
342 |
+
captions = ["vid_frames_teacher_forcing"] * len(local_files)
|
343 |
+
for local_path, caption in zip(local_files, captions):
|
344 |
+
wandb.log(
|
345 |
+
{"frames": [wandb.Image(local_path, caption=caption)]},
|
346 |
+
step=iteration,
|
347 |
+
)
|
348 |
+
|
349 |
+
wandb.log({"eval/teacher_forcing_acc": avg_acc}, step=iteration)
|
350 |
+
wandb.log({"eval/teacher_forcing_loss": avg_loss}, step=iteration)
|
351 |
+
if percent_token_diff is not None:
|
352 |
+
wandb.log({"eval/percent_token_diff_train_val": percent_token_diff.item() * 100}, step=iteration)
|
cosmos_predict1/autoregressive/configs/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/configs/base/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/configs/base/callbacks.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing
|
17 |
+
from cosmos_predict1.callbacks.grad_clip import GradClip
|
18 |
+
from cosmos_predict1.utils.callback import ProgressBarCallback
|
19 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
20 |
+
|
21 |
+
BASIC_CALLBACKS = dict(
|
22 |
+
progress_bar=L(ProgressBarCallback)(),
|
23 |
+
grad_clip=L(GradClip)(clip_norm=1.0, fsdp_enabled="${model.model_config.fsdp_enabled}", model_key="model"),
|
24 |
+
)
|
25 |
+
|
26 |
+
VIDEO_TEACHER_FORCING_CALLBACK = dict(
|
27 |
+
vid_sampling_tf=L(VideoSamplingTeacherForcing)(
|
28 |
+
every_n=500,
|
29 |
+
video_latent_shape="${model.model_config.video_latent_shape}",
|
30 |
+
num_frames_to_display=4,
|
31 |
+
save_folder="video_sampling_teacher_forcing",
|
32 |
+
)
|
33 |
+
)
|
cosmos_predict1/autoregressive/configs/base/dataloader.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from megatron.core import parallel_state
|
17 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
18 |
+
|
19 |
+
from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
|
20 |
+
from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset
|
21 |
+
from cosmos_predict1.utils import log
|
22 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
23 |
+
|
24 |
+
DATALOADER_OPTIONS = {}
|
25 |
+
|
26 |
+
|
27 |
+
def get_sampler(dataset):
|
28 |
+
return DistributedSampler(
|
29 |
+
dataset,
|
30 |
+
num_replicas=parallel_state.get_data_parallel_world_size(),
|
31 |
+
rank=parallel_state.get_data_parallel_rank(),
|
32 |
+
shuffle=True,
|
33 |
+
seed=0,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def dataloader_register(key):
|
38 |
+
log.info(f"registering dataloader {key}...")
|
39 |
+
|
40 |
+
def decorator(func):
|
41 |
+
DATALOADER_OPTIONS[key] = func
|
42 |
+
return func
|
43 |
+
|
44 |
+
return decorator
|
45 |
+
|
46 |
+
|
47 |
+
@dataloader_register("tealrobot_video")
|
48 |
+
def get_tealrobot_video(
|
49 |
+
batch_size: int = 1,
|
50 |
+
dataset_dir: str = "datasets/cosmos_nemo_assets/videos/",
|
51 |
+
sequence_interval: int = 1,
|
52 |
+
num_frames: int = 33,
|
53 |
+
video_size: list[int, int] = [640, 848],
|
54 |
+
start_frame_interval: int = 1,
|
55 |
+
):
|
56 |
+
dataset = L(VideoDataset)(
|
57 |
+
config=VideoDatasetConfig(
|
58 |
+
dataset_dir=dataset_dir,
|
59 |
+
sequence_interval=sequence_interval,
|
60 |
+
num_frames=num_frames,
|
61 |
+
video_size=video_size,
|
62 |
+
start_frame_interval=start_frame_interval,
|
63 |
+
)
|
64 |
+
)
|
65 |
+
return L(DataLoader)(
|
66 |
+
dataset=dataset,
|
67 |
+
sampler=L(get_sampler)(dataset=dataset),
|
68 |
+
batch_size=batch_size,
|
69 |
+
drop_last=True,
|
70 |
+
pin_memory=True,
|
71 |
+
num_workers=8,
|
72 |
+
)
|
cosmos_predict1/autoregressive/configs/base/dataset.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Dataset config class."""
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.utils.config import make_freezable
|
21 |
+
|
22 |
+
|
23 |
+
@make_freezable
|
24 |
+
@attrs.define(slots=False)
|
25 |
+
class VideoDatasetConfig:
|
26 |
+
"""
|
27 |
+
Args:
|
28 |
+
dataset_dir (str): Base path to the dataset directory
|
29 |
+
sequence_interval (int): Interval between sampled frames in a sequence
|
30 |
+
num_frames (int): Number of frames to load per sequence
|
31 |
+
video_size (list): Target size [H,W] for video frames
|
32 |
+
start_frame_interval (int): Interval between starting frames of sequences
|
33 |
+
"""
|
34 |
+
|
35 |
+
dataset_dir: str = "datasets/cosmos_nemo_assets/videos/"
|
36 |
+
sequence_interval: int = 1
|
37 |
+
num_frames: int = 33
|
38 |
+
video_size: list[int, int] = [640, 848]
|
39 |
+
start_frame_interval: int = 1
|
cosmos_predict1/autoregressive/configs/base/model.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig
|
21 |
+
from cosmos_predict1.utils import config
|
22 |
+
|
23 |
+
_ACTION_DIM = 8
|
24 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
25 |
+
|
26 |
+
|
27 |
+
@attrs.define
|
28 |
+
class ModelConfig:
|
29 |
+
"""
|
30 |
+
A class to hold model configuration arguments.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
dim (int): The dimensionality of the input and output of each transformer block.
|
34 |
+
n_layers (int): Number of layers in the transformer.
|
35 |
+
n_heads (int): Number of attention heads.
|
36 |
+
n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
|
37 |
+
`num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
|
38 |
+
head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
|
39 |
+
vocab_size (int): Vocabulary size.
|
40 |
+
ffn_hidden_size (int): Hidden size for feedforward network.
|
41 |
+
norm_eps (float): Epsilon value for normalization.
|
42 |
+
rope_theta (float): Theta value for rotary positional embeddings.
|
43 |
+
apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
|
44 |
+
max_batch_size (int): Maximum batch size for inference.
|
45 |
+
max_seq_len (int): Maximum sequence length for input text.
|
46 |
+
fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True.
|
47 |
+
causal_mask (bool): Whether to use causal mask. Defaults to True.
|
48 |
+
norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
|
49 |
+
precision (str): Data type for the model.
|
50 |
+
use_qk_normalization (bool): Whether to enable QK normalization.
|
51 |
+
tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
|
52 |
+
ckpt_dir (str): Checkpoint directory.
|
53 |
+
ckpt_path (str): Checkpoint path.
|
54 |
+
apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
|
55 |
+
yarn_scale (Optional[float]): Scale factor for YaRN.
|
56 |
+
yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
|
57 |
+
yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
|
58 |
+
original_seq_len (Optional[int]): Original sequence length.
|
59 |
+
vision_encoder (Optional[str]): Vision encoder name.
|
60 |
+
mm_projector (Optional[str]): Multi-modal projector name.
|
61 |
+
vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
|
62 |
+
rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
|
63 |
+
pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
|
64 |
+
original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
|
65 |
+
pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
|
66 |
+
vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3.
|
67 |
+
insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
|
68 |
+
insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
|
69 |
+
context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
|
70 |
+
num_video_frames (Optional[int]): Number of video frames.
|
71 |
+
video_height (Optional[int]): Raw video pixel height dimension.
|
72 |
+
video_width (Optional[int]): Raw video pixel width dimension.
|
73 |
+
video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W).
|
74 |
+
"""
|
75 |
+
|
76 |
+
dim: int = attrs.field(default=4096)
|
77 |
+
n_layers: int = attrs.field(default=32)
|
78 |
+
n_heads: int = attrs.field(default=32)
|
79 |
+
n_kv_heads: Optional[int] = attrs.field(default=8)
|
80 |
+
head_dim: Optional[int] = attrs.field(default=None)
|
81 |
+
vocab_size: int = attrs.field(default=128256)
|
82 |
+
ffn_hidden_size: int = attrs.field(default=14336)
|
83 |
+
norm_eps: float = attrs.field(default=1e-5)
|
84 |
+
rope_theta: float = attrs.field(default=500000)
|
85 |
+
apply_abs_pos_emb: bool = attrs.field(default=False)
|
86 |
+
max_batch_size: int = attrs.field(default=1)
|
87 |
+
max_seq_len: int = attrs.field(default=8192)
|
88 |
+
fuse_qkv: bool = attrs.field(default=False)
|
89 |
+
causal_mask: bool = attrs.field(default=True)
|
90 |
+
norm_type: str = attrs.field(default="rmsnorm")
|
91 |
+
precision: str = attrs.field(default="bfloat16")
|
92 |
+
use_qk_normalization: bool = False
|
93 |
+
tokenizer: Optional[TokenizerConfig] = None
|
94 |
+
tensor_model_parallel_size: int = attrs.field(default=1)
|
95 |
+
ckpt_dir: Optional[str] = attrs.field(default=None)
|
96 |
+
ckpt_path: Optional[str] = attrs.field(
|
97 |
+
default=None
|
98 |
+
) # If not None, load the model from this path instead of ckpt_dir
|
99 |
+
apply_yarn: Optional[bool] = attrs.field(default=False)
|
100 |
+
yarn_scale: Optional[float] = attrs.field(default=None)
|
101 |
+
yarn_beta_fast: Optional[int] = attrs.field(default=None)
|
102 |
+
yarn_beta_slow: Optional[int] = attrs.field(default=None)
|
103 |
+
original_seq_len: Optional[int] = attrs.field(default=None)
|
104 |
+
vision_encoder: Optional[str] = attrs.field(default=None)
|
105 |
+
vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
|
106 |
+
mm_projector: Optional[str] = attrs.field(default=None)
|
107 |
+
rope_dim: Optional[str] = attrs.field(default="1D")
|
108 |
+
pytorch_rope_version: Optional[str] = attrs.field(default="v2")
|
109 |
+
original_latent_shape: Optional[list] = None
|
110 |
+
pad_to_multiple_of: Optional[int] = None
|
111 |
+
vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
|
112 |
+
insert_cross_attn: bool = False
|
113 |
+
insert_cross_attn_every_k_layers: int = 1
|
114 |
+
context_dim: Optional[int] = attrs.field(default=1024)
|
115 |
+
# For video training
|
116 |
+
num_video_frames: Optional[int] = None
|
117 |
+
# Raw video pixel dimension
|
118 |
+
video_height: Optional[int] = None
|
119 |
+
video_width: Optional[int] = None
|
120 |
+
# Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact
|
121 |
+
video_latent_shape: Optional[list] = None
|
122 |
+
|
123 |
+
def __getitem__(self, item):
|
124 |
+
return getattr(self, item)
|
125 |
+
|
126 |
+
|
127 |
+
@attrs.define
|
128 |
+
class TrainingModelConfig:
|
129 |
+
"""
|
130 |
+
A class to hold model configuration arguments.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
dim (int): The dimensionality of the input and output of each transformer block.
|
134 |
+
n_layers (int): Number of layers in the transformer.
|
135 |
+
n_heads (int): Number of attention heads.
|
136 |
+
n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
|
137 |
+
`num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
|
138 |
+
head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
|
139 |
+
vocab_size (int): Vocabulary size.
|
140 |
+
multiple_of (int): Ensures the hidden layer size is a multiple of this value for SwiGLU activation.
|
141 |
+
ffn_dim_multiplier (Optional[float]): Multiplier for feedforward network dimension.
|
142 |
+
ffn_hidden_size (Optional[int]): Hidden size for feedforward network. If None, use ffn_dim_multiplier to compute it.
|
143 |
+
norm_eps (float): Epsilon value for normalization.
|
144 |
+
rope_theta (float): Theta value for rotary positional embeddings.
|
145 |
+
apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
|
146 |
+
max_batch_size (int): Maximum batch size for inference (determines KV cache size).
|
147 |
+
max_seq_len (int): Maximum sequence length for input text (determines KV cache size).
|
148 |
+
fuse_qkv (bool): Whether to fuse QKV in attention. Flag for the pytorch backend.
|
149 |
+
causal_mask (bool): Whether to use causal mask. Defaults to True.
|
150 |
+
flash_attn (bool): Whether to use Flash attention.
|
151 |
+
norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
|
152 |
+
backend (str): Backend for the model.
|
153 |
+
precision (str): Data type for the model.
|
154 |
+
ema (config.EMAConfig): Configuration for exponential moving average.
|
155 |
+
embedding_dropout(float): Dropout rate for the embedding layer.
|
156 |
+
attention_dropout(float): Dropout rate for attention.
|
157 |
+
hidden_dropout(float): Dropout after the attention and feed-forward layers (following TransformerEngine's
|
158 |
+
implementation in its TransformerLayer class).
|
159 |
+
use_qk_normalization (bool): Whether to enable QK normalization.
|
160 |
+
inference (bool): Whether the model is used for inference.
|
161 |
+
act_ckpt_enabled (bool): Whether to enable activation checkpointing.
|
162 |
+
fsdp_enabled (bool): Whether to enable FSDP.
|
163 |
+
fsdp (LazyDict): Configuration for FSDP.
|
164 |
+
ckpt_dir (str): Checkpoint directory.
|
165 |
+
ckpt_path (str): Checkpoint path.
|
166 |
+
cache_dir (str): Cache directory.
|
167 |
+
apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
|
168 |
+
yarn_scale (Optional[float]): Scale factor for YaRN.
|
169 |
+
yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
|
170 |
+
yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
|
171 |
+
original_seq_len (Optional[int]): Original sequence length.
|
172 |
+
depth_init (bool): If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the
|
173 |
+
total number of transformer blocks. Defaults to `True` (following the TorchTitan implementation of Llama3).
|
174 |
+
context_parallel_size (int): Context parallel size. Defaults to 1.
|
175 |
+
tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
|
176 |
+
sequence_parallel (bool): Whether to use sequence parallelism. Defaults to False.
|
177 |
+
set_parallel_mode (bool): It is a boolean flag used by TransformerEngine to handle Tensor Parallelism.
|
178 |
+
Essentially, it is equivalent to `tensor_model_parallel_size > 1`. Defaults to `False`.
|
179 |
+
attention_tp (bool): Whether to use tensor parallelism for attention layers.
|
180 |
+
mm_projector (Optional[str]): Multimodal projector used for vision-language modeling. Defaults to None.
|
181 |
+
Choices: "identity", "linear", "mlp", "mlp_downsample".
|
182 |
+
video_latent_shape (Optional[list]): Shape of the video latent tensor. [T, H, W]
|
183 |
+
image_latent_shape (Optional[list]): Shape of the image latent tensor. [H, W]
|
184 |
+
num_video_frames (Optional[int]): Number of video frames.
|
185 |
+
rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
|
186 |
+
pytorch_rope_version (Optional[str]): Version of the RoPE for the `pytorch` backend. "v1" is the Llama implementation, and "v2" is HuggingFace/TransformerEngine implementation.
|
187 |
+
original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
|
188 |
+
pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
|
189 |
+
peft_last_n_layers (Optional[int]): Number of last few layers to fine-tune in Parameter Efficient Fine-Tuning (PEFT). When this and peft_every_n_layers are both 0, it means all layers are fine-tuned (FFT).
|
190 |
+
peft_every_n_layers (Optional[int]): In Parameter Efficient Fine-Tuning (PEFT), every n layers are unfrozen and can be trained (in flamingo style). When this and peft_last_n_layers are both 0,
|
191 |
+
it means all layers are fine-tuned (FFT). For example, for a 40 layer model, n=8 means training layers 7, 15, 23, 31, 39, which includes the final layer.
|
192 |
+
It is advised to pick n such that the final layer is included.
|
193 |
+
freeze_vision_encoder (bool): Whether to freeze the vision encoder in vision-language model training. Defaults to False.
|
194 |
+
vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
|
195 |
+
insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
|
196 |
+
insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
|
197 |
+
context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
|
198 |
+
finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
|
199 |
+
finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
|
200 |
+
use_action_condition (bool): Whether to use the robot action condition.
|
201 |
+
action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
|
202 |
+
action_dim (Optional[int]): The dimensionality of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
|
203 |
+
action_embedding_dim (Optional[int]): The dimensionality of the robot action embedding.
|
204 |
+
group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
|
205 |
+
sync_1d_parameters (bool): Whether to synchronize layernorm parameters (1D) across tensor parallel ranks (default True).
|
206 |
+
Note: this is to ensure all TP-ranks have the same layernorm parameters.
|
207 |
+
z_loss_coeff (float): The coefficient for the z-loss.
|
208 |
+
insert_medusa_head (bool): Whether to insert the Medusa head.
|
209 |
+
ft_medusa_option (str): Options on which layers to finetune, choices like:
|
210 |
+
"fft": fully fine-tune both medusa heads and all LLM backbone;
|
211 |
+
"head": fine-tune medusa heads;
|
212 |
+
"head_out": fine-tune medusa heads, and the output layer;
|
213 |
+
"head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
|
214 |
+
medusa_num_heads (int): Number of heads in the Medusa head.
|
215 |
+
medusa_num_layers (int): Number of layers in the Medusa head.
|
216 |
+
medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
|
217 |
+
zero_init_cross_attn_proj (bool): Whether to initialize the cross-attn proj layer with zeros (default False).
|
218 |
+
concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
|
219 |
+
"""
|
220 |
+
|
221 |
+
dim: int = attrs.field(default=4096)
|
222 |
+
n_layers: int = attrs.field(default=32)
|
223 |
+
n_heads: int = attrs.field(default=32)
|
224 |
+
n_kv_heads: Optional[int] = attrs.field(default=8)
|
225 |
+
head_dim: Optional[int] = attrs.field(default=None)
|
226 |
+
vocab_size: int = attrs.field(default=128256)
|
227 |
+
multiple_of: int = attrs.field(default=1024) # make SwiGLU hidden layer size multiple of large power of 2
|
228 |
+
ffn_dim_multiplier: Optional[float] = attrs.field(default=1.3)
|
229 |
+
ffn_hidden_size: Optional[int] = attrs.field(default=None)
|
230 |
+
norm_eps: float = attrs.field(default=1e-5)
|
231 |
+
rope_theta: float = attrs.field(default=500000)
|
232 |
+
apply_abs_pos_emb: bool = attrs.field(default=False)
|
233 |
+
max_batch_size: int = attrs.field(default=1)
|
234 |
+
max_seq_len: int = attrs.field(default=8192)
|
235 |
+
fuse_qkv: bool = attrs.field(default=False)
|
236 |
+
causal_mask: bool = attrs.field(default=True)
|
237 |
+
flash_attn: bool = attrs.field(default=True)
|
238 |
+
norm_type: str = attrs.field(default="rmsnorm")
|
239 |
+
backend: str = attrs.field(default="pytorch")
|
240 |
+
precision: str = attrs.field(default="bfloat16")
|
241 |
+
ema: config.EMAConfig = config.EMAConfig(enabled=False)
|
242 |
+
embedding_dropout: float = 0.0
|
243 |
+
attention_dropout: float = 0.0
|
244 |
+
hidden_dropout: float = 0.0
|
245 |
+
use_qk_normalization: bool = False
|
246 |
+
tokenizer: Optional[TokenizerConfig] = None
|
247 |
+
inference: bool = False
|
248 |
+
act_ckpt_enabled: bool = False
|
249 |
+
fsdp_enabled: bool = False
|
250 |
+
context_parallel_size: int = attrs.field(default=1)
|
251 |
+
tensor_model_parallel_size: int = attrs.field(default=1)
|
252 |
+
sequence_parallel: bool = attrs.field(default=False)
|
253 |
+
set_parallel_mode: bool = attrs.field(default=False)
|
254 |
+
fsdp: LazyDict = LazyDict(
|
255 |
+
dict(
|
256 |
+
policy="auto", # choices: ["size", "auto"]
|
257 |
+
min_num_params=1024, # Used as policy == "size"
|
258 |
+
sharding_strategy="hybrid", # Choices: ["full", "hybrid"]. "full" means sharding_group_size = world_size
|
259 |
+
sharding_group_size=8, # If None, defaults to min(world_size, 8). Recommends 8 for training on 8-GPU nodes.
|
260 |
+
)
|
261 |
+
)
|
262 |
+
ckpt_dir: Optional[str] = attrs.field(default="")
|
263 |
+
ckpt_path: Optional[str] = attrs.field(
|
264 |
+
default=None
|
265 |
+
) # If not None, load the model from this path instead of ckpt_dir
|
266 |
+
cache_dir: Optional[str] = attrs.field(default="/project/cosmos/ar/cache")
|
267 |
+
apply_yarn: Optional[bool] = attrs.field(default=False)
|
268 |
+
yarn_scale: Optional[float] = attrs.field(default=None)
|
269 |
+
yarn_beta_fast: Optional[int] = attrs.field(default=None)
|
270 |
+
yarn_beta_slow: Optional[int] = attrs.field(default=None)
|
271 |
+
original_seq_len: Optional[int] = attrs.field(default=None)
|
272 |
+
depth_init: bool = attrs.field(default=True)
|
273 |
+
ignore_first_num_tokens: int = 0
|
274 |
+
z_loss_coeff: float = 1e-4
|
275 |
+
attention_tp: bool = False
|
276 |
+
vision_encoder: Optional[str] = attrs.field(default=None)
|
277 |
+
mm_projector: Optional[str] = attrs.field(default=None)
|
278 |
+
rope_dim: Optional[str] = attrs.field(default="1D")
|
279 |
+
pytorch_rope_version: Optional[str] = attrs.field(default="v2")
|
280 |
+
original_latent_shape: Optional[list] = None
|
281 |
+
pad_to_multiple_of: Optional[int] = None
|
282 |
+
peft_last_n_layers: Optional[int] = attrs.field(default=0)
|
283 |
+
peft_every_n_layers: Optional[int] = attrs.field(default=0)
|
284 |
+
freeze_vision_encoder: bool = False
|
285 |
+
vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
|
286 |
+
insert_cross_attn: bool = False
|
287 |
+
insert_cross_attn_every_k_layers: int = 1
|
288 |
+
context_dim: Optional[int] = attrs.field(default=1024)
|
289 |
+
finetune_layers_with_cross_attn: bool = False
|
290 |
+
finetune_layers_without_cross_attn: bool = False
|
291 |
+
use_action_condition: bool = False
|
292 |
+
action_embedding_mode: Optional[str] = attrs.field(default="mlp")
|
293 |
+
action_dim: Optional[int] = attrs.field(default=_ACTION_DIM)
|
294 |
+
action_embedding_dim: Optional[int] = attrs.field(default=1024)
|
295 |
+
group_causal_mask_mode: Optional[str] = attrs.field(default=None)
|
296 |
+
sync_1d_parameters: bool = True
|
297 |
+
# hyper-parameters for the medusa head configs
|
298 |
+
insert_medusa_head: bool = False
|
299 |
+
ft_medusa_option: str = "fft"
|
300 |
+
medusa_num_heads: int = 7
|
301 |
+
medusa_num_layers: int = 1
|
302 |
+
medusa_concat_heads: bool = True
|
303 |
+
# For video training
|
304 |
+
num_video_frames: Optional[int] = None
|
305 |
+
# Raw video pixel dimension
|
306 |
+
video_height: Optional[int] = None
|
307 |
+
video_width: Optional[int] = None
|
308 |
+
# Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact
|
309 |
+
video_latent_shape: Optional[list] = None
|
310 |
+
# For image training
|
311 |
+
image_latent_shape: Optional[list] = None
|
312 |
+
# For robot training (action)
|
313 |
+
zero_init_cross_attn_proj: bool = False
|
314 |
+
# For robot training (action)
|
315 |
+
concat_action_to_context: bool = False
|
316 |
+
|
317 |
+
def __getitem__(self, item):
|
318 |
+
return getattr(self, item)
|
cosmos_predict1/autoregressive/configs/base/model_config.py
ADDED
@@ -0,0 +1,718 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import copy
|
17 |
+
from typing import Callable, List, Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from megatron.core import ModelParallelConfig
|
21 |
+
|
22 |
+
from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TrainingModelConfig
|
23 |
+
from cosmos_predict1.autoregressive.configs.base.tokenizer import (
|
24 |
+
TextTokenizerConfig,
|
25 |
+
TokenizerConfig,
|
26 |
+
VideoTokenizerConfig,
|
27 |
+
create_discrete_video_fsq_tokenizer_state_dict_config,
|
28 |
+
)
|
29 |
+
from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer
|
30 |
+
from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer
|
31 |
+
from cosmos_predict1.autoregressive.training.model import AutoRegressiveTrainingModel
|
32 |
+
from cosmos_predict1.utils import log
|
33 |
+
from cosmos_predict1.utils.config import EMAConfig
|
34 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
35 |
+
|
36 |
+
# Common architecture specifications
|
37 |
+
BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
|
38 |
+
COSMOS_ARCHITECTURES = {
|
39 |
+
"1b": {
|
40 |
+
"n_layers": 16,
|
41 |
+
"dim": 2048,
|
42 |
+
"n_heads": 32,
|
43 |
+
},
|
44 |
+
"4b": {
|
45 |
+
"n_layers": 16,
|
46 |
+
"dim": 4096,
|
47 |
+
"n_heads": 32,
|
48 |
+
},
|
49 |
+
"12b": {
|
50 |
+
"n_layers": 40,
|
51 |
+
"dim": 5120,
|
52 |
+
"n_heads": 32,
|
53 |
+
"head_dim": 128,
|
54 |
+
},
|
55 |
+
}
|
56 |
+
|
57 |
+
COSMOS_YARN_CONFIG = {
|
58 |
+
"original_latent_shape": [3, 40, 64],
|
59 |
+
"apply_yarn": True,
|
60 |
+
"yarn_beta_fast": 4,
|
61 |
+
"yarn_beta_slow": 1,
|
62 |
+
"yarn_scale": 2,
|
63 |
+
}
|
64 |
+
|
65 |
+
# Llama3 architecture specifications for different model sizes
|
66 |
+
LLAMA3_ARCHITECTURES = {
|
67 |
+
"8b": {
|
68 |
+
"n_layers": 32,
|
69 |
+
"dim": 4096,
|
70 |
+
"n_heads": 32,
|
71 |
+
"ffn_hidden_size": 14336,
|
72 |
+
},
|
73 |
+
}
|
74 |
+
# Llama3.1 uses YaRN for long context support (context of 128k tokens)
|
75 |
+
LLAMA_YARN_CONFIG = {
|
76 |
+
"apply_yarn": True,
|
77 |
+
"yarn_scale": 8,
|
78 |
+
"yarn_beta_fast": 4,
|
79 |
+
"yarn_beta_slow": 1,
|
80 |
+
}
|
81 |
+
|
82 |
+
# Mistral architecture specifications for different model sizes
|
83 |
+
MISTRAL_ARCHITECTURES = {
|
84 |
+
"12b": {
|
85 |
+
"n_layers": 40,
|
86 |
+
"dim": 5120,
|
87 |
+
"n_heads": 32,
|
88 |
+
"ffn_hidden_size": 14336,
|
89 |
+
"head_dim": 128,
|
90 |
+
},
|
91 |
+
}
|
92 |
+
|
93 |
+
PIXTRAL_VISION_ARCHITECTURES = {
|
94 |
+
"12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"},
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict:
|
99 |
+
"""
|
100 |
+
Get the model architecture specifications for the given model size, model family and pretrained status.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc.
|
104 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral"
|
105 |
+
pretrained (bool): Whether to load pretrained weights.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
dict: A dictionary containing the model architecture specifications.
|
109 |
+
"""
|
110 |
+
arch_specs = copy.deepcopy(BASE_CONFIG)
|
111 |
+
model_size = model_size.lower()
|
112 |
+
if model_family.startswith("cosmos"):
|
113 |
+
arch_specs.update(COSMOS_ARCHITECTURES[model_size])
|
114 |
+
elif model_family.startswith("llama"):
|
115 |
+
arch_specs.update(LLAMA3_ARCHITECTURES[model_size])
|
116 |
+
elif model_family in ["mistral", "pixtral"]:
|
117 |
+
arch_specs.update(MISTRAL_ARCHITECTURES[model_size])
|
118 |
+
if model_family == "pixtral":
|
119 |
+
arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size])
|
120 |
+
else:
|
121 |
+
raise ValueError(f"Model family {model_family} is not supported.")
|
122 |
+
|
123 |
+
if pretrained:
|
124 |
+
if model_family == "cosmos":
|
125 |
+
if model_size == "12b":
|
126 |
+
arch_specs.update(COSMOS_YARN_CONFIG)
|
127 |
+
log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}")
|
128 |
+
else:
|
129 |
+
pass
|
130 |
+
elif model_family in ["llama", "llama3"]:
|
131 |
+
pretrained_specs = {
|
132 |
+
"rope_theta": 500000,
|
133 |
+
"max_seq_len": 8192,
|
134 |
+
"vocab_size": 128256,
|
135 |
+
}
|
136 |
+
arch_specs.update(pretrained_specs)
|
137 |
+
elif model_family == "llama3.1":
|
138 |
+
pretrained_specs = {
|
139 |
+
"rope_theta": 500000,
|
140 |
+
"max_seq_len": 131072,
|
141 |
+
"original_seq_len": 8192,
|
142 |
+
"vocab_size": 128256,
|
143 |
+
**LLAMA_YARN_CONFIG,
|
144 |
+
}
|
145 |
+
arch_specs.update(pretrained_specs)
|
146 |
+
elif model_family == "mistral":
|
147 |
+
assert model_size == "12b", "We only support Mistral-Nemo-12B model."
|
148 |
+
pretrained_specs = {
|
149 |
+
"rope_theta": 1000000,
|
150 |
+
"max_seq_len": 128000,
|
151 |
+
"vocab_size": 131072,
|
152 |
+
}
|
153 |
+
arch_specs.update(pretrained_specs)
|
154 |
+
elif model_family == "pixtral":
|
155 |
+
assert model_size == "12b", "We only support Pixtral 12B model."
|
156 |
+
pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072}
|
157 |
+
arch_specs.update(pretrained_specs)
|
158 |
+
else:
|
159 |
+
raise ValueError(f"Model family {model_family} doesn't have a pretrained config.")
|
160 |
+
|
161 |
+
return arch_specs
|
162 |
+
|
163 |
+
|
164 |
+
def create_text_model_config(
|
165 |
+
model_ckpt_path: str,
|
166 |
+
tokenizer_path: str,
|
167 |
+
tensor_model_parallel_size: int = 1,
|
168 |
+
model_family: str = "mistral",
|
169 |
+
model_size: str = "12b",
|
170 |
+
is_instruct_model: bool = True,
|
171 |
+
max_seq_len: int = None,
|
172 |
+
max_batch_size: int = 1,
|
173 |
+
rope_dim: str = "1D",
|
174 |
+
add_special_tokens: bool = True,
|
175 |
+
pytorch_rope_version: str = None,
|
176 |
+
) -> dict:
|
177 |
+
"""Create a text model for training or inference.
|
178 |
+
Args:
|
179 |
+
model_ckpt_path (str): Path to the model checkpoint.
|
180 |
+
tokenizer_path (str): Path to the tokenizer folder.
|
181 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
182 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
|
183 |
+
model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc.
|
184 |
+
is_instruct_model (bool): Whether the model is an instruct model.
|
185 |
+
inference (bool): Whether to create the model for inference.
|
186 |
+
max_seq_len (int): Maximum sequence length.
|
187 |
+
max_batch_size (int): Maximum batch size.
|
188 |
+
rope_dim (str): RoPE dimension. Choices: "1D", "3D".
|
189 |
+
add_special_tokens (bool): Whether to add special tokens.
|
190 |
+
Returns:
|
191 |
+
dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
|
192 |
+
"""
|
193 |
+
# Model size specific parameters
|
194 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
|
195 |
+
if max_seq_len is not None:
|
196 |
+
# Override the max_seq_len if provided
|
197 |
+
model_arch_specs["max_seq_len"] = max_seq_len
|
198 |
+
if pytorch_rope_version is not None:
|
199 |
+
model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
|
200 |
+
model_config = ModelConfig(
|
201 |
+
max_batch_size=max_batch_size,
|
202 |
+
precision="bfloat16",
|
203 |
+
ckpt_path=model_ckpt_path,
|
204 |
+
use_qk_normalization=False,
|
205 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
206 |
+
rope_dim=rope_dim,
|
207 |
+
**model_arch_specs,
|
208 |
+
)
|
209 |
+
|
210 |
+
tokenizer_config = TokenizerConfig(
|
211 |
+
text_tokenizer=TextTokenizerConfig(
|
212 |
+
config=L(TextTokenizer)(
|
213 |
+
model_family=model_family,
|
214 |
+
is_instruct_model=is_instruct_model,
|
215 |
+
local_path=tokenizer_path,
|
216 |
+
),
|
217 |
+
data_key="text",
|
218 |
+
tokenizer_offset=model_config.vocab_size,
|
219 |
+
tokenize_here=False,
|
220 |
+
vocab_size=model_config.vocab_size,
|
221 |
+
),
|
222 |
+
seq_len=model_config.max_seq_len,
|
223 |
+
training_type="text_only",
|
224 |
+
add_special_tokens=add_special_tokens,
|
225 |
+
)
|
226 |
+
return model_config, tokenizer_config
|
227 |
+
|
228 |
+
|
229 |
+
def create_vision_language_model_config(
|
230 |
+
model_ckpt_path: str,
|
231 |
+
tokenizer_ckpt_path: str,
|
232 |
+
tensor_model_parallel_size: int = 1,
|
233 |
+
model_family: str = "pixtral",
|
234 |
+
model_size: str = "12b",
|
235 |
+
is_instruct_model: bool = True,
|
236 |
+
max_batch_size: int = 1,
|
237 |
+
rope_dim: str = "1D",
|
238 |
+
add_special_tokens: bool = True,
|
239 |
+
max_seq_len: int = None,
|
240 |
+
vision_encoder_in_channels: int = 3,
|
241 |
+
fuse_qkv: bool = False,
|
242 |
+
pytorch_rope_version: str = None,
|
243 |
+
) -> dict:
|
244 |
+
"""Create a vision-language model for training or inference.
|
245 |
+
Args:
|
246 |
+
model_ckpt_path (str): Path to the model checkpoint.
|
247 |
+
tokenizer_ckpt_path (str): Path to the tokenizer checkpoint.
|
248 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
249 |
+
model_family (str): Model family. Choices: "pixtral".
|
250 |
+
model_size (str): Model size. Choices: "12b".
|
251 |
+
is_instruct_model (bool): Whether the model is an instruct model.
|
252 |
+
rope_dim (str): RoPE dimension. Choices: "1D".
|
253 |
+
add_special_tokens (bool): Whether to add special tokens.
|
254 |
+
max_seq_len (int): Maximum sequence length.
|
255 |
+
vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4 channel images where last channel is binary mask, set this to 4.
|
256 |
+
fuse_qkv (bool): Whether to fuse the QKV linear layers.
|
257 |
+
Returns:
|
258 |
+
dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
|
259 |
+
"""
|
260 |
+
# Model size specific parameters
|
261 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
|
262 |
+
if max_seq_len is not None:
|
263 |
+
# Override the max_seq_len if provided
|
264 |
+
model_arch_specs["max_seq_len"] = max_seq_len
|
265 |
+
if pytorch_rope_version is not None:
|
266 |
+
model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
|
267 |
+
|
268 |
+
model_config = ModelConfig(
|
269 |
+
max_batch_size=max_batch_size,
|
270 |
+
precision="bfloat16",
|
271 |
+
ckpt_path=model_ckpt_path,
|
272 |
+
use_qk_normalization=False,
|
273 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
274 |
+
rope_dim=rope_dim,
|
275 |
+
vision_encoder_in_channels=vision_encoder_in_channels,
|
276 |
+
fuse_qkv=fuse_qkv,
|
277 |
+
**model_arch_specs,
|
278 |
+
)
|
279 |
+
# Vision-language tokenizer
|
280 |
+
tokenizer_config = TokenizerConfig(
|
281 |
+
text_tokenizer=TextTokenizerConfig(
|
282 |
+
config=L(ImageTextTokenizer)(
|
283 |
+
model_family=model_family,
|
284 |
+
is_instruct_model=is_instruct_model,
|
285 |
+
image_processor_path=tokenizer_ckpt_path,
|
286 |
+
tokenizer_path=tokenizer_ckpt_path,
|
287 |
+
),
|
288 |
+
data_key="image_text_interleaved",
|
289 |
+
tokenizer_offset=model_config.vocab_size,
|
290 |
+
tokenize_here=False,
|
291 |
+
vocab_size=model_config.vocab_size,
|
292 |
+
),
|
293 |
+
seq_len=model_config.max_seq_len,
|
294 |
+
training_type="image_text_interleaved",
|
295 |
+
add_special_tokens=add_special_tokens,
|
296 |
+
)
|
297 |
+
return model_config, tokenizer_config
|
298 |
+
|
299 |
+
|
300 |
+
def create_video2world_model_config(
|
301 |
+
model_ckpt_path: str,
|
302 |
+
tokenizer_ckpt_path: str,
|
303 |
+
tensor_model_parallel_size: int = 1,
|
304 |
+
model_family: str = "cosmos",
|
305 |
+
model_size: str = "4b",
|
306 |
+
pixel_chunk_duration: int = 9,
|
307 |
+
num_video_frames: int = 36,
|
308 |
+
compression_ratio: List[int] = [8, 16, 16],
|
309 |
+
original_seq_len: int = 8192,
|
310 |
+
num_condition_latents_t: int = 1,
|
311 |
+
num_tokens_to_ignore: int = -1,
|
312 |
+
batch_size: int = 2,
|
313 |
+
video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
|
314 |
+
rope_dim: str = "3D",
|
315 |
+
add_special_tokens: bool = True,
|
316 |
+
video_height: int = 384,
|
317 |
+
video_width: int = 640,
|
318 |
+
use_qk_normalization: bool = True,
|
319 |
+
insert_cross_attn: bool = False,
|
320 |
+
insert_cross_attn_every_k_layers: int = 1,
|
321 |
+
context_dim: int = 1024,
|
322 |
+
training_type: str = "video_to_video",
|
323 |
+
pad_to_multiple_of: Optional[int] = 64,
|
324 |
+
vocab_size: int = 64000,
|
325 |
+
apply_abs_pos_emb: bool = False,
|
326 |
+
) -> dict:
|
327 |
+
"""Create a video-to-world model config.
|
328 |
+
Args:
|
329 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
330 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
|
331 |
+
model_size (str): Model size. Choices: "1b", "8b", "3b".
|
332 |
+
pixel_chunk_duration (int): Number of frames in each chunk.
|
333 |
+
num_video_frames (int): Number of video frames.
|
334 |
+
compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
|
335 |
+
original_seq_len (int): Original sequence length.
|
336 |
+
apply_yarn (bool): Whether to apply YaRN for long context scaling.
|
337 |
+
yarn_beta_fast (Optional[int]): Fast beta for YaRN.
|
338 |
+
yarn_beta_slow (Optional[int]): Slow beta for YaRN.
|
339 |
+
yarn_scale (Optional[int]): Scale factor for ctx extension.
|
340 |
+
use_qk_normalization (bool): Whether to use Query-Key normalization.
|
341 |
+
training_type (str): Type of training task.
|
342 |
+
batch_size (int): Batch size.
|
343 |
+
video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
|
344 |
+
video_tokenizer_version (str): Version of the video tokenizer.
|
345 |
+
num_condition_latents_t (int): Number of conditioning latent channels
|
346 |
+
num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence
|
347 |
+
video_height (int): Height of the video frame. Defaults to 384.
|
348 |
+
video_width (int): Width of the video frame. Defaults to 640.
|
349 |
+
rope_dim (str): RoPE dimension. Choices: "1D", "3D".
|
350 |
+
add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
|
351 |
+
pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
|
352 |
+
vocab_size (int): Vocabulary size.
|
353 |
+
apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings.
|
354 |
+
Returns:
|
355 |
+
dict: A dictionary containing the model configuration representing the model object, can be instantiated.
|
356 |
+
"""
|
357 |
+
assert (
|
358 |
+
pixel_chunk_duration % compression_ratio[0] == 1
|
359 |
+
), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
|
360 |
+
latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
|
361 |
+
latent_height = video_height // compression_ratio[1]
|
362 |
+
latent_width = video_width // compression_ratio[2]
|
363 |
+
# Do some math to compute the video latent shape and sequence length
|
364 |
+
assert (
|
365 |
+
num_video_frames % pixel_chunk_duration == 0
|
366 |
+
), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
|
367 |
+
video_latent_shape = [
|
368 |
+
num_video_frames // pixel_chunk_duration * latent_chunk_duration,
|
369 |
+
latent_height,
|
370 |
+
latent_width,
|
371 |
+
]
|
372 |
+
# product of video_latent_shape
|
373 |
+
num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
|
374 |
+
if add_special_tokens:
|
375 |
+
seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
|
376 |
+
seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
|
377 |
+
# for text to video, we need to add <bov> token to indicate the start of the video
|
378 |
+
elif training_type == "text_to_video":
|
379 |
+
seq_len = num_token_video_latent + 1
|
380 |
+
else:
|
381 |
+
seq_len = num_token_video_latent
|
382 |
+
|
383 |
+
if seq_len % pad_to_multiple_of != 0:
|
384 |
+
# Round up to the nearest multiple of pad_to_multiple_of
|
385 |
+
seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
|
386 |
+
|
387 |
+
# Model size specific parameters
|
388 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
|
389 |
+
|
390 |
+
# Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss
|
391 |
+
# If num_tokens_to_ignore is specified, use it.
|
392 |
+
# Else compute it from num_condition_latents_t
|
393 |
+
if num_tokens_to_ignore < 0:
|
394 |
+
num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
|
395 |
+
if not add_special_tokens and num_condition_latents_t > 0:
|
396 |
+
# If there are no special tokens (bov), do a -1 so that you can compute the loss
|
397 |
+
# from the first token of the next chunk
|
398 |
+
num_tokens_to_ignore -= 1
|
399 |
+
|
400 |
+
model_config = ModelConfig(
|
401 |
+
video_height=video_height,
|
402 |
+
video_width=video_width,
|
403 |
+
max_seq_len=seq_len,
|
404 |
+
max_batch_size=batch_size,
|
405 |
+
precision="bfloat16",
|
406 |
+
ckpt_path=model_ckpt_path,
|
407 |
+
use_qk_normalization=use_qk_normalization,
|
408 |
+
vocab_size=64000,
|
409 |
+
original_seq_len=original_seq_len,
|
410 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
411 |
+
video_latent_shape=video_latent_shape,
|
412 |
+
num_video_frames=num_video_frames,
|
413 |
+
rope_dim=rope_dim,
|
414 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
415 |
+
insert_cross_attn=insert_cross_attn,
|
416 |
+
insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
|
417 |
+
context_dim=context_dim,
|
418 |
+
apply_abs_pos_emb=apply_abs_pos_emb,
|
419 |
+
**model_arch_specs,
|
420 |
+
)
|
421 |
+
|
422 |
+
video_tokenizer_config = video_tokenizer_config_creator(
|
423 |
+
tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio
|
424 |
+
)
|
425 |
+
tokenizer_config = TokenizerConfig(
|
426 |
+
text_tokenizer=None,
|
427 |
+
video_tokenizer=VideoTokenizerConfig(
|
428 |
+
config=video_tokenizer_config,
|
429 |
+
data_key="video",
|
430 |
+
tokenizer_offset=0, # Since there is no text embeddings in the model. Note this only apply when the model is trained from scratch. If we use text pretrained model, the offset will be vocab_size of text token.
|
431 |
+
tokenize_here=True,
|
432 |
+
max_seq_len=num_token_video_latent,
|
433 |
+
vocab_size=vocab_size,
|
434 |
+
),
|
435 |
+
seq_len=seq_len,
|
436 |
+
training_type=training_type,
|
437 |
+
add_special_tokens=add_special_tokens,
|
438 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
439 |
+
)
|
440 |
+
return model_config, tokenizer_config
|
441 |
+
|
442 |
+
|
443 |
+
def create_video2world_model(
|
444 |
+
tensor_model_parallel_size: int = 1,
|
445 |
+
context_parallel_size: int = 1,
|
446 |
+
shard_checkpoint: bool = False,
|
447 |
+
model_family: str = "cosmos",
|
448 |
+
model_size: str = "1b",
|
449 |
+
backend: str = "pytorch",
|
450 |
+
pixel_chunk_duration: int = 9,
|
451 |
+
num_video_frames: int = 36,
|
452 |
+
compression_ratio: List[int] = [8, 16, 16],
|
453 |
+
original_seq_len: int = 8192,
|
454 |
+
apply_yarn: bool = False,
|
455 |
+
yarn_beta_fast: Optional[int] = None,
|
456 |
+
yarn_beta_slow: Optional[int] = None,
|
457 |
+
yarn_scale: Optional[int] = None,
|
458 |
+
num_condition_latents_t: int = 1,
|
459 |
+
num_tokens_to_ignore: int = -1,
|
460 |
+
batch_size: int = 1,
|
461 |
+
fsdp_enabled: bool = False,
|
462 |
+
act_ckpt_enabled: bool = False,
|
463 |
+
video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
|
464 |
+
rope_dim: str = "3D",
|
465 |
+
add_special_tokens: bool = False,
|
466 |
+
video_height: int = 384,
|
467 |
+
video_width: int = 640,
|
468 |
+
original_latent_shape: Optional[List[int]] = None,
|
469 |
+
use_qk_normalization: bool = True,
|
470 |
+
sequence_parallel: bool = False,
|
471 |
+
insert_cross_attn: bool = False,
|
472 |
+
insert_cross_attn_every_k_layers: int = 1,
|
473 |
+
context_dim: int = 1024,
|
474 |
+
finetune_layers_with_cross_attn: bool = False,
|
475 |
+
finetune_layers_without_cross_attn: bool = False,
|
476 |
+
use_action_condition: bool = False,
|
477 |
+
action_embedding_mode: Optional[str] = "mlp",
|
478 |
+
action_dim: int = 8, # ACTION_DIM,
|
479 |
+
action_embedding_dim: int = 1024,
|
480 |
+
group_causal_mask_mode: Optional[str] = None,
|
481 |
+
training_type: str = "video_to_video",
|
482 |
+
pad_to_multiple_of: Optional[int] = 1,
|
483 |
+
z_loss_coeff: float = 1e-4,
|
484 |
+
temporal_overlap: int = 0,
|
485 |
+
embedding_dropout: float = 0.0,
|
486 |
+
insert_medusa_head: bool = False,
|
487 |
+
ft_medusa_option: str = "fft",
|
488 |
+
medusa_num_heads: int = 7,
|
489 |
+
medusa_num_layers: int = 1,
|
490 |
+
medusa_concat_heads: bool = True,
|
491 |
+
fuse_qkv: bool = False,
|
492 |
+
zero_init_cross_attn_proj: bool = False,
|
493 |
+
concat_action_to_context: bool = False,
|
494 |
+
tokenizer_ckpt_path: str = "checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit",
|
495 |
+
) -> dict:
|
496 |
+
"""Create a video-to-video model for training.
|
497 |
+
Args:
|
498 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
499 |
+
context_parallel_size (int): Number of context parallel groups.
|
500 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
|
501 |
+
model_size (str): Model size. Choices: "1b", "8b", "3b".
|
502 |
+
backend (str): Backend for the model. Choices: "pytorch", "transformer_engine".
|
503 |
+
pixel_chunk_duration (int): Number of frames in each chunk.
|
504 |
+
num_video_frames (int): Number of video frames.
|
505 |
+
compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
|
506 |
+
original_seq_len (int): Original sequence length.
|
507 |
+
apply_yarn (bool): Whether to apply YaRN for long context scaling.
|
508 |
+
yarn_beta_fast (Optional[int]): Fast beta for YaRN.
|
509 |
+
yarn_beta_slow (Optional[int]): Slow beta for YaRN.
|
510 |
+
yarn_scale (Optional[int]): Scale factor for ctx extension.
|
511 |
+
fsdp_enabled (bool): Whether Fully Sharded Data Parallel (FSDP) is enabled.
|
512 |
+
act_ckpt_enabled (bool): Whether activation checkpointing is enabled.
|
513 |
+
use_qk_normalization (bool): Whether to use Query-Key normalization.
|
514 |
+
training_type (str): Type of training task.
|
515 |
+
batch_size (int): Batch size.
|
516 |
+
video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
|
517 |
+
video_tokenizer_version (str): Version of the video tokenizer.
|
518 |
+
num_condition_latents_t (int): Number of conditioning latent channels
|
519 |
+
num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence
|
520 |
+
video_height (int): Height of the video frame. Defaults to 384.
|
521 |
+
video_width (int): Width of the video frame. Defaults to 640.
|
522 |
+
rope_dim (str): RoPE dimension. Choices: "1D", "2D", "3D".
|
523 |
+
add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
|
524 |
+
original_latent_shape (list): Original latent shape before RoPE scaling.
|
525 |
+
sequence_parallel (bool): Whether to enable sequence parallelism.
|
526 |
+
insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
|
527 |
+
insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
|
528 |
+
context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
|
529 |
+
finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
|
530 |
+
finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
|
531 |
+
use_action_condition (bool): Whether to use action condition.
|
532 |
+
action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
|
533 |
+
action_dim (int): Dimension of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
|
534 |
+
action_embedding_dim (int): Dimension of the action embedding.
|
535 |
+
group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
|
536 |
+
pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
|
537 |
+
z_loss_coeff (float): Coefficient for the z loss.
|
538 |
+
temporal_overlap (int): Temporal overlap in the latent space.
|
539 |
+
embedding_dropout (float): Dropout rate for the embeddings.
|
540 |
+
insert_medusa_head (bool): Whether to insert the Medusa head.
|
541 |
+
ft_medusa_option (str): Options on which layers to finetune, choices like:
|
542 |
+
"fft": fully fine-tune both medusa heads and all LLM backbone;
|
543 |
+
"head": fine-tune medusa heads;
|
544 |
+
"head_out": fine-tune medusa heads, and the output layer;
|
545 |
+
"head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
|
546 |
+
medusa_num_heads (int): Number of heads in the Medusa head.
|
547 |
+
medusa_num_layers (int): Number of layers in the Medusa head.
|
548 |
+
medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
|
549 |
+
fuse_qkv (bool): Whether to fuse the QKV linear layers.
|
550 |
+
zero_init_cross_attn_proj (bool): Whether to zero-initialize the cross-attention projection weights (default False).
|
551 |
+
concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
|
552 |
+
Returns:
|
553 |
+
dict: A dictionary containing the model configuration representing the model object, can be instantiated.
|
554 |
+
"""
|
555 |
+
assert (
|
556 |
+
pixel_chunk_duration % compression_ratio[0] == 1
|
557 |
+
), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
|
558 |
+
latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
|
559 |
+
latent_height = video_height // compression_ratio[1]
|
560 |
+
latent_width = video_width // compression_ratio[2]
|
561 |
+
# Compute the video latent shape and sequence length
|
562 |
+
if temporal_overlap == 0:
|
563 |
+
assert (
|
564 |
+
num_video_frames % pixel_chunk_duration == 0
|
565 |
+
), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
|
566 |
+
video_latent_shape = [
|
567 |
+
num_video_frames // pixel_chunk_duration * latent_chunk_duration,
|
568 |
+
latent_height,
|
569 |
+
latent_width,
|
570 |
+
]
|
571 |
+
|
572 |
+
else:
|
573 |
+
# Calculate temporal overlap in the latent space
|
574 |
+
temporal_overlap_latent = temporal_overlap // compression_ratio[0]
|
575 |
+
|
576 |
+
# Calculate the effective number of latent chunks for the video
|
577 |
+
latent_chunks = (num_video_frames - temporal_overlap) // (pixel_chunk_duration - temporal_overlap)
|
578 |
+
|
579 |
+
# Compute the total duration of the latent chunks, accounting for overlap
|
580 |
+
effective_latent_duration = (
|
581 |
+
latent_chunk_duration - temporal_overlap_latent
|
582 |
+
) * latent_chunks + temporal_overlap_latent
|
583 |
+
|
584 |
+
# Define the shape of the video in the latent space
|
585 |
+
video_latent_shape = [
|
586 |
+
effective_latent_duration, # Temporal dimension
|
587 |
+
latent_height, # Height in the latent space
|
588 |
+
latent_width, # Width in the latent space
|
589 |
+
]
|
590 |
+
|
591 |
+
# product of video_latent_shape
|
592 |
+
num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
|
593 |
+
if add_special_tokens:
|
594 |
+
seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
|
595 |
+
seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
|
596 |
+
# for text to video, we need to add <bov> token to indicate the start of the video
|
597 |
+
elif training_type == "text_to_video":
|
598 |
+
seq_len = num_token_video_latent + 1
|
599 |
+
else:
|
600 |
+
seq_len = num_token_video_latent
|
601 |
+
|
602 |
+
if seq_len % pad_to_multiple_of != 0:
|
603 |
+
# Round up to the nearest multiple of pad_to_multiple_of
|
604 |
+
seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
|
605 |
+
|
606 |
+
# Model size specific parameters
|
607 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=False)
|
608 |
+
|
609 |
+
inference = False # False for training, True for inference
|
610 |
+
# set_parallel_mode = True
|
611 |
+
set_parallel_mode = tensor_model_parallel_size > 1
|
612 |
+
attention_tp = True
|
613 |
+
|
614 |
+
if context_parallel_size > 1:
|
615 |
+
assert backend == "transformer_engine", "Context parallelism is only supported in transformer engine."
|
616 |
+
|
617 |
+
if tensor_model_parallel_size > 1:
|
618 |
+
assert set_parallel_mode, "Tensor model parallelism is only supported in parallel mode."
|
619 |
+
|
620 |
+
# Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss
|
621 |
+
# If num_tokens_to_ignore is specified, use it.
|
622 |
+
# Else compute it from num_condition_latents_t
|
623 |
+
if num_tokens_to_ignore < 0:
|
624 |
+
num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
|
625 |
+
if not add_special_tokens and num_condition_latents_t > 0:
|
626 |
+
# If there are no special tokens (bov), do a -1 so that you can compute the loss
|
627 |
+
# from the first token of the next chunk
|
628 |
+
num_tokens_to_ignore -= 1
|
629 |
+
|
630 |
+
model_config = TrainingModelConfig(
|
631 |
+
video_height=video_height,
|
632 |
+
video_width=video_width,
|
633 |
+
max_seq_len=seq_len,
|
634 |
+
max_batch_size=batch_size,
|
635 |
+
inference=inference,
|
636 |
+
backend=backend,
|
637 |
+
precision="bfloat16",
|
638 |
+
ema=EMAConfig(enabled=False),
|
639 |
+
act_ckpt_enabled=act_ckpt_enabled,
|
640 |
+
fsdp_enabled=fsdp_enabled,
|
641 |
+
cache_dir=None,
|
642 |
+
ckpt_path="checkpoints/Cosmos-Predict1-4B/model.pt",
|
643 |
+
use_qk_normalization=use_qk_normalization,
|
644 |
+
vocab_size=64000,
|
645 |
+
ignore_first_num_tokens=num_tokens_to_ignore,
|
646 |
+
apply_yarn=apply_yarn,
|
647 |
+
yarn_beta_fast=yarn_beta_fast,
|
648 |
+
yarn_beta_slow=yarn_beta_slow,
|
649 |
+
original_seq_len=original_seq_len,
|
650 |
+
yarn_scale=yarn_scale,
|
651 |
+
context_parallel_size=context_parallel_size,
|
652 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
653 |
+
set_parallel_mode=set_parallel_mode,
|
654 |
+
attention_tp=attention_tp,
|
655 |
+
video_latent_shape=video_latent_shape,
|
656 |
+
num_video_frames=num_video_frames,
|
657 |
+
rope_dim=rope_dim,
|
658 |
+
original_latent_shape=original_latent_shape,
|
659 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
660 |
+
sequence_parallel=sequence_parallel,
|
661 |
+
insert_cross_attn=insert_cross_attn,
|
662 |
+
insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
|
663 |
+
context_dim=context_dim,
|
664 |
+
finetune_layers_with_cross_attn=finetune_layers_with_cross_attn,
|
665 |
+
finetune_layers_without_cross_attn=finetune_layers_without_cross_attn,
|
666 |
+
use_action_condition=use_action_condition,
|
667 |
+
action_embedding_mode=action_embedding_mode,
|
668 |
+
action_dim=action_dim,
|
669 |
+
action_embedding_dim=action_embedding_dim,
|
670 |
+
group_causal_mask_mode=group_causal_mask_mode,
|
671 |
+
z_loss_coeff=z_loss_coeff,
|
672 |
+
embedding_dropout=embedding_dropout,
|
673 |
+
insert_medusa_head=insert_medusa_head,
|
674 |
+
ft_medusa_option=ft_medusa_option,
|
675 |
+
medusa_num_heads=medusa_num_heads,
|
676 |
+
medusa_num_layers=medusa_num_layers,
|
677 |
+
medusa_concat_heads=medusa_concat_heads,
|
678 |
+
fuse_qkv=fuse_qkv,
|
679 |
+
zero_init_cross_attn_proj=zero_init_cross_attn_proj,
|
680 |
+
concat_action_to_context=concat_action_to_context,
|
681 |
+
**model_arch_specs,
|
682 |
+
)
|
683 |
+
|
684 |
+
tokenizer_config = TokenizerConfig(
|
685 |
+
text_tokenizer=None,
|
686 |
+
video_tokenizer=VideoTokenizerConfig(
|
687 |
+
config=video_tokenizer_config_creator(
|
688 |
+
ckpt_path=tokenizer_ckpt_path, pixel_chunk_duration=pixel_chunk_duration
|
689 |
+
),
|
690 |
+
data_key="video",
|
691 |
+
tokenizer_offset=0,
|
692 |
+
vocab_size=64000,
|
693 |
+
tokenize_here=True,
|
694 |
+
max_seq_len=num_token_video_latent,
|
695 |
+
temporal_overlap=temporal_overlap,
|
696 |
+
),
|
697 |
+
seq_len="${model.model_config.max_seq_len}",
|
698 |
+
training_type=training_type,
|
699 |
+
add_special_tokens=add_special_tokens,
|
700 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
701 |
+
)
|
702 |
+
|
703 |
+
model_parallel = ModelParallelConfig(
|
704 |
+
bf16=True,
|
705 |
+
params_dtype=getattr(torch, "bfloat16"),
|
706 |
+
)
|
707 |
+
model_parallel.tensor_model_parallel_size = "${model.model_config.tensor_model_parallel_size}"
|
708 |
+
model_parallel.context_parallel_size = "${model.model_config.context_parallel_size}"
|
709 |
+
model_parallel.sequence_parallel = "${model.model_config.sequence_parallel}"
|
710 |
+
return L(AutoRegressiveTrainingModel.build)(
|
711 |
+
seed=0,
|
712 |
+
train_from_scratch=True,
|
713 |
+
model_config=model_config,
|
714 |
+
fsdp_checkpointer=None,
|
715 |
+
tokenizer_config=tokenizer_config,
|
716 |
+
model_parallel=model_parallel,
|
717 |
+
shard_checkpoint=shard_checkpoint,
|
718 |
+
)
|
cosmos_predict1/autoregressive/configs/base/model_parallel.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from megatron.core import ModelParallelConfig
|
18 |
+
|
19 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
20 |
+
|
21 |
+
|
22 |
+
def create_model_parallel_config():
|
23 |
+
model_parallel = ModelParallelConfig(bf16=True, params_dtype=getattr(torch, "bfloat16"))
|
24 |
+
model_parallel.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}"
|
25 |
+
model_parallel.context_parallel_size = "${model.model_parallel.context_parallel_size}"
|
26 |
+
model_parallel.sequence_parallel = "${model.model_parallel.sequence_parallel}"
|
27 |
+
MODEL_PARALLELS = LazyDict(
|
28 |
+
dict(
|
29 |
+
model_parallel_bf16=model_parallel,
|
30 |
+
),
|
31 |
+
flags={"allow_objects": True},
|
32 |
+
)
|
33 |
+
return MODEL_PARALLELS["model_parallel_bf16"]
|
cosmos_predict1/autoregressive/configs/base/optim.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
19 |
+
|
20 |
+
|
21 |
+
class LambdaLinearWarmupScheduler:
|
22 |
+
"""
|
23 |
+
A learning rate scheduler that implements linear warm-up and cool-down.
|
24 |
+
|
25 |
+
This scheduler provides three phases:
|
26 |
+
1. Warm-up: Learning rate linearly increases from 0 to 1.
|
27 |
+
2. Constant: Learning rate remains at 1.
|
28 |
+
3. Cool-down: Learning rate linearly decreases from 1 to 0.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
warmup_steps (int): Number of steps for the warm-up phase.
|
32 |
+
warmup_offset (int): Starts warmup from this offset.
|
33 |
+
max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided.
|
34 |
+
cooldown_steps (int, optional): Number of steps for the cool-down phase.
|
35 |
+
|
36 |
+
Raises:
|
37 |
+
ValueError: If cooldown_steps is provided without max_iter, or if an invalid step is given.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, warmup_steps: int, warmup_offset: int = 0, max_iter: int = None, cooldown_steps: int = None):
|
41 |
+
self.warmup_steps = warmup_steps
|
42 |
+
self.warmup_offset = warmup_offset
|
43 |
+
self.max_iter = max_iter
|
44 |
+
self.cooldown_steps = cooldown_steps
|
45 |
+
|
46 |
+
if cooldown_steps is not None:
|
47 |
+
if max_iter is None:
|
48 |
+
raise ValueError("max_iter must be specified when cooldown_steps is provided")
|
49 |
+
self.cooldown_start = max_iter - cooldown_steps
|
50 |
+
else:
|
51 |
+
self.cooldown_start = None
|
52 |
+
|
53 |
+
def __call__(self, step):
|
54 |
+
# Warm-up phase
|
55 |
+
if step < self.warmup_offset:
|
56 |
+
return 0
|
57 |
+
|
58 |
+
if step < self.warmup_steps + self.warmup_offset:
|
59 |
+
return float(step - self.warmup_offset) / float(max(1, self.warmup_steps))
|
60 |
+
|
61 |
+
# Constant phase (no cool-down)
|
62 |
+
elif self.cooldown_steps is None:
|
63 |
+
return 1.0
|
64 |
+
|
65 |
+
# Constant phase (before cool-down starts)
|
66 |
+
elif step < self.cooldown_start:
|
67 |
+
return 1.0
|
68 |
+
|
69 |
+
# Cool-down phase
|
70 |
+
elif self.cooldown_start <= step < self.max_iter:
|
71 |
+
cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps
|
72 |
+
return 1.0 - cooldown_progress
|
73 |
+
|
74 |
+
# After max_iter
|
75 |
+
elif step >= self.max_iter:
|
76 |
+
return 0.0
|
77 |
+
|
78 |
+
# Unexpected case
|
79 |
+
else:
|
80 |
+
raise ValueError(f"Invalid step {step}")
|
81 |
+
|
82 |
+
|
83 |
+
LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)(
|
84 |
+
optimizer=None,
|
85 |
+
lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000),
|
86 |
+
)
|
cosmos_predict1/autoregressive/configs/base/tokenizer.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer
|
21 |
+
from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer
|
22 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
23 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
24 |
+
|
25 |
+
|
26 |
+
def create_discrete_video_fsq_tokenizer_state_dict_config(
|
27 |
+
ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16]
|
28 |
+
) -> LazyDict:
|
29 |
+
CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
|
30 |
+
# The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime.
|
31 |
+
# - It relies on fully 3D discrete wavelet transform
|
32 |
+
# - Uses a layer norm instead of a group norm
|
33 |
+
# - Factorizes full convolutions into spatial and temporal convolutions
|
34 |
+
# - Factorizes full attention into spatial and temporal attention
|
35 |
+
# - Strictly causal, with flexible temporal length at inference.
|
36 |
+
attn_resolutions=[32],
|
37 |
+
channels=128,
|
38 |
+
channels_mult=[2, 4, 4],
|
39 |
+
dropout=0.0,
|
40 |
+
in_channels=3,
|
41 |
+
num_res_blocks=2,
|
42 |
+
out_channels=3,
|
43 |
+
resolution=1024,
|
44 |
+
patch_size=4,
|
45 |
+
patch_method="haar",
|
46 |
+
z_channels=16,
|
47 |
+
z_factor=1,
|
48 |
+
num_groups=1,
|
49 |
+
legacy_mode=False,
|
50 |
+
spatial_compression=16,
|
51 |
+
temporal_compression=8,
|
52 |
+
embedding_dim=6,
|
53 |
+
levels=[8, 8, 8, 5, 5, 5],
|
54 |
+
name="CausalDiscreteFactorizedVideoTokenizer",
|
55 |
+
)
|
56 |
+
|
57 |
+
return L(DiscreteVideoFSQStateDictTokenizer)(
|
58 |
+
enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
|
59 |
+
dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
|
60 |
+
tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
|
61 |
+
name="discrete_video_fsq",
|
62 |
+
latent_ch=6,
|
63 |
+
is_bf16=True,
|
64 |
+
pixel_chunk_duration=pixel_chunk_duration,
|
65 |
+
latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
|
66 |
+
max_enc_batch_size=8,
|
67 |
+
max_dec_batch_size=4,
|
68 |
+
levels=[8, 8, 8, 5, 5, 5],
|
69 |
+
compression_ratio=compression_ratio,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
@attrs.define(slots=False)
|
74 |
+
class TextTokenizerConfig:
|
75 |
+
"""
|
76 |
+
Text tokenizer config
|
77 |
+
|
78 |
+
Args:
|
79 |
+
config: Config file to define the text tokenizer class.
|
80 |
+
data_key (str): The input key from data_dict that will be passed to the text tokenizer.
|
81 |
+
tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
|
82 |
+
tokenizer_offset (int): Offset that is added to the tokens.
|
83 |
+
vocab_size (int): Vocabulary size of the tokenizer.
|
84 |
+
"""
|
85 |
+
|
86 |
+
config: LazyDict
|
87 |
+
data_key: str = ""
|
88 |
+
tokenize_here: bool = False
|
89 |
+
tokenizer_offset: int = 0
|
90 |
+
vocab_size: int = 0
|
91 |
+
|
92 |
+
|
93 |
+
@attrs.define(slots=False)
|
94 |
+
class VideoTokenizerConfig:
|
95 |
+
"""
|
96 |
+
Video tokenizer config
|
97 |
+
|
98 |
+
Args:
|
99 |
+
config: Config file to define the video tokenizer class.
|
100 |
+
data_key (str): The input key from data_dict that will be passed to the video tokenizer.
|
101 |
+
tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
|
102 |
+
tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
|
103 |
+
add an offset to make sure that video tokens and text tokens don't overlap.
|
104 |
+
vocab_size (int): Vocabulary size of the tokenizer.
|
105 |
+
max_seq_len (int): Maximum token length for an input video.
|
106 |
+
temporal_overlap (int): Overlap between consecutive video chunks.
|
107 |
+
"""
|
108 |
+
|
109 |
+
config: LazyDict
|
110 |
+
data_key: str = ""
|
111 |
+
tokenize_here: bool = True
|
112 |
+
tokenizer_offset: int = 0
|
113 |
+
vocab_size: int = 0
|
114 |
+
max_seq_len: int = -1
|
115 |
+
temporal_overlap: int = 0
|
116 |
+
|
117 |
+
|
118 |
+
@attrs.define(slots=False)
|
119 |
+
class TokenizerConfig:
|
120 |
+
"""
|
121 |
+
Joint tokenizer config
|
122 |
+
|
123 |
+
Args:
|
124 |
+
text_tokenizer (TextTokenizerConfig): Text tokenizer config file
|
125 |
+
class_tokenizer (ClassTokenizerConfig): Class tokenizer config file
|
126 |
+
video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
|
127 |
+
image_tokenizer (ImageTokenizerConfig): Image tokenizer config file
|
128 |
+
seq_len (int): Final token sequence length
|
129 |
+
training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
|
130 |
+
add_special_tokens (bool): Whether to add special tokens to the output tokens
|
131 |
+
pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
|
132 |
+
"""
|
133 |
+
|
134 |
+
text_tokenizer: Optional[TextTokenizerConfig] = None
|
135 |
+
video_tokenizer: Optional[VideoTokenizerConfig] = None
|
136 |
+
seq_len: int = 4096
|
137 |
+
training_type: str = None
|
138 |
+
add_special_tokens: bool = True
|
139 |
+
pad_to_multiple_of: Optional[int] = 64
|
cosmos_predict1/autoregressive/configs/config.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Default config for cosmos_ar project."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
from typing import Any, List
|
20 |
+
|
21 |
+
import attrs
|
22 |
+
|
23 |
+
from cosmos_predict1.autoregressive.configs.registry import register_configs
|
24 |
+
from cosmos_predict1.autoregressive.trainer import Trainer
|
25 |
+
from cosmos_predict1.utils import config, log
|
26 |
+
from cosmos_predict1.utils.config_helper import import_all_modules_from_package
|
27 |
+
|
28 |
+
|
29 |
+
@attrs.define(slots=False)
|
30 |
+
class Config(config.Config):
|
31 |
+
defaults: List[Any] = attrs.field(
|
32 |
+
factory=lambda: [
|
33 |
+
"_self_",
|
34 |
+
{"model": None},
|
35 |
+
{"data_train": "mock_video"},
|
36 |
+
{"data_val": None},
|
37 |
+
{"optimizer": "fused_adamw"},
|
38 |
+
{"scheduler": "warmup_cosine_lr"},
|
39 |
+
{"checkpoint": "local"},
|
40 |
+
{"callbacks": "basic"},
|
41 |
+
{"global_config": None},
|
42 |
+
{"experiment": None},
|
43 |
+
]
|
44 |
+
)
|
45 |
+
|
46 |
+
def validate(self) -> None:
|
47 |
+
"""Validate that the config has all required fields."""
|
48 |
+
assert self.job.project != "", "job.project is not set"
|
49 |
+
assert self.job.group != "", "job.group is not set"
|
50 |
+
assert self.job.name != "", "job.name is not set"
|
51 |
+
log.info("Validating config for cosmos_autoregressive job")
|
52 |
+
# FSDP config check
|
53 |
+
if self.model.model_config.fsdp_enabled:
|
54 |
+
assert self.trainer.distributed_parallelism == "fsdp"
|
55 |
+
else:
|
56 |
+
assert self.trainer.distributed_parallelism == "ddp"
|
57 |
+
|
58 |
+
# Transformer Engine config check
|
59 |
+
if self.model.model_config.backend == "transformer_engine":
|
60 |
+
assert (
|
61 |
+
"NVTE_FLASH_ATTN" in os.environ and os.environ["NVTE_FLASH_ATTN"] == "1"
|
62 |
+
) # Enable Flash attention for transformer engine
|
63 |
+
|
64 |
+
# TP, CP config check
|
65 |
+
if self.model_parallel is not None:
|
66 |
+
if self.model_parallel.context_parallel_size > 1:
|
67 |
+
assert (
|
68 |
+
self.model.model_config.backend == "transformer_engine"
|
69 |
+
), "Context parallelism is only supported in transformer engine."
|
70 |
+
|
71 |
+
if self.model_parallel.tensor_model_parallel_size > 1:
|
72 |
+
assert (
|
73 |
+
self.model.model_config.set_parallel_mode
|
74 |
+
), "Tensor model parallelism is only supported in parallel mode."
|
75 |
+
|
76 |
+
if self.model_parallel.sequence_parallel:
|
77 |
+
assert (
|
78 |
+
self.model_parallel.tensor_model_parallel_size > 1
|
79 |
+
), "Sequence parallelism is only supported in tensor model parallelism."
|
80 |
+
assert (
|
81 |
+
self.model.model_config.backend == "transformer_engine"
|
82 |
+
), "Sequence parallelism is only supported in transformer engine."
|
83 |
+
|
84 |
+
|
85 |
+
def make_config():
|
86 |
+
c = Config(
|
87 |
+
model=None,
|
88 |
+
optimizer=None,
|
89 |
+
scheduler=None,
|
90 |
+
dataloader_train=None,
|
91 |
+
dataloader_val=None,
|
92 |
+
checkpoint=None,
|
93 |
+
)
|
94 |
+
|
95 |
+
c.job.project = "cosmos_autoregressive"
|
96 |
+
c.job.group = "debug"
|
97 |
+
c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}"
|
98 |
+
|
99 |
+
c.trainer.type = Trainer
|
100 |
+
c.trainer.run_validation = True
|
101 |
+
|
102 |
+
c.trainer.seed = 0
|
103 |
+
c.trainer.max_iter = 10
|
104 |
+
c.trainer.logging_iter = 1
|
105 |
+
|
106 |
+
c.trainer.callbacks = None
|
107 |
+
register_configs()
|
108 |
+
# experiment config are defined in the experiment folder
|
109 |
+
# call import_all_modules_from_package to register them
|
110 |
+
import_all_modules_from_package("cosmos_predict1.autoregressive.configs.experiment")
|
111 |
+
return c
|
cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py
ADDED
File without changes
|
cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""
|
17 |
+
This file contains a basic configuration for video2video experiments.
|
18 |
+
"""
|
19 |
+
|
20 |
+
from hydra.core.config_store import ConfigStore
|
21 |
+
|
22 |
+
from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model
|
23 |
+
from cosmos_predict1.autoregressive.configs.base.model_parallel import create_model_parallel_config
|
24 |
+
from cosmos_predict1.utils import log
|
25 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
26 |
+
|
27 |
+
cs = ConfigStore.instance()
|
28 |
+
|
29 |
+
|
30 |
+
"""
|
31 |
+
Finetune 4B model with TP=1, pytorch backend, low resolution tealrobot data, frames 33, chunk 33.
|
32 |
+
Usage:
|
33 |
+
torchrun --nproc_per_node=1 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobotsmall_tp1
|
34 |
+
"""
|
35 |
+
base_4b_example_tealrobotsmall_tp1: LazyDict = LazyDict(
|
36 |
+
dict(
|
37 |
+
defaults=[
|
38 |
+
{"override /data_train": "tealrobot_video_small"},
|
39 |
+
{
|
40 |
+
"override /callbacks": [
|
41 |
+
"basic",
|
42 |
+
"video_teacher_forcing",
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{"override /checkpoint": "local"},
|
46 |
+
{"override /optimizer": "fused_adamw"},
|
47 |
+
{"override /scheduler": "warmup_cosine_lr"},
|
48 |
+
"_self_",
|
49 |
+
],
|
50 |
+
job=dict(
|
51 |
+
project="posttraining",
|
52 |
+
group="autoregressive_base",
|
53 |
+
name="base_4b_example_tealrobotsmall_tp1",
|
54 |
+
),
|
55 |
+
model=create_video2world_model(
|
56 |
+
model_size="4b",
|
57 |
+
model_family="cosmos",
|
58 |
+
backend="pytorch",
|
59 |
+
tensor_model_parallel_size=1,
|
60 |
+
batch_size=1,
|
61 |
+
pixel_chunk_duration=33,
|
62 |
+
num_video_frames=33,
|
63 |
+
video_height=384,
|
64 |
+
video_width=640,
|
65 |
+
tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
|
66 |
+
add_special_tokens=False,
|
67 |
+
),
|
68 |
+
trainer=dict(
|
69 |
+
max_iter=50000,
|
70 |
+
grad_accum_iter=1,
|
71 |
+
grad_scaler_args=dict(enabled=False),
|
72 |
+
run_validation=False, # No need for validation as epoch <= 1
|
73 |
+
distributed_parallelism="ddp",
|
74 |
+
callbacks=dict(
|
75 |
+
vid_sampling_tf=dict(
|
76 |
+
every_n=500,
|
77 |
+
),
|
78 |
+
),
|
79 |
+
),
|
80 |
+
checkpoint=dict(
|
81 |
+
load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
|
82 |
+
load_training_state=False,
|
83 |
+
strict_resume=True,
|
84 |
+
save_iter=1000,
|
85 |
+
),
|
86 |
+
model_parallel=create_model_parallel_config(),
|
87 |
+
),
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
"""
|
92 |
+
Finetune 4B model with TP=4, pytorch backend, high resolution tealrobot data, frame 33, chunk 33.
|
93 |
+
Usage:
|
94 |
+
torchrun --nproc_per_node=4 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobot_tp4
|
95 |
+
"""
|
96 |
+
base_4b_example_tealrobot_tp4: LazyDict = LazyDict(
|
97 |
+
dict(
|
98 |
+
defaults=[
|
99 |
+
{"override /data_train": "tealrobot_video"},
|
100 |
+
{
|
101 |
+
"override /callbacks": [
|
102 |
+
"basic",
|
103 |
+
"video_teacher_forcing",
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{"override /checkpoint": "local"},
|
107 |
+
{"override /optimizer": "fused_adamw"},
|
108 |
+
{"override /scheduler": "warmup_cosine_lr"},
|
109 |
+
"_self_",
|
110 |
+
],
|
111 |
+
job=dict(
|
112 |
+
project="posttraining",
|
113 |
+
group="autoregressive_base",
|
114 |
+
name="base_4b_example_tealrobot_tp4",
|
115 |
+
),
|
116 |
+
model=create_video2world_model(
|
117 |
+
model_size="4b",
|
118 |
+
model_family="cosmos",
|
119 |
+
backend="pytorch",
|
120 |
+
tensor_model_parallel_size=4,
|
121 |
+
batch_size=1,
|
122 |
+
pixel_chunk_duration=33,
|
123 |
+
num_video_frames=33,
|
124 |
+
video_height=640,
|
125 |
+
video_width=848,
|
126 |
+
tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
|
127 |
+
add_special_tokens=False,
|
128 |
+
),
|
129 |
+
trainer=dict(
|
130 |
+
max_iter=50000,
|
131 |
+
grad_accum_iter=1,
|
132 |
+
grad_scaler_args=dict(enabled=False),
|
133 |
+
run_validation=False, # No need for validation as epoch <= 1
|
134 |
+
distributed_parallelism="ddp",
|
135 |
+
callbacks=dict(
|
136 |
+
vid_sampling_tf=dict(
|
137 |
+
every_n=500,
|
138 |
+
),
|
139 |
+
),
|
140 |
+
),
|
141 |
+
checkpoint=dict(
|
142 |
+
load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
|
143 |
+
load_training_state=False,
|
144 |
+
strict_resume=False,
|
145 |
+
save_iter=1000,
|
146 |
+
),
|
147 |
+
model_parallel=create_model_parallel_config(),
|
148 |
+
),
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
def register_experiments(cs):
|
153 |
+
# Register the experiments
|
154 |
+
for _item in [
|
155 |
+
base_4b_example_tealrobotsmall_tp1,
|
156 |
+
base_4b_example_tealrobot_tp4,
|
157 |
+
]:
|
158 |
+
cs.store(
|
159 |
+
group="experiment",
|
160 |
+
package="_global_",
|
161 |
+
name=_item["job"]["name"],
|
162 |
+
node=_item,
|
163 |
+
)
|
cosmos_predict1/autoregressive/configs/inference/inference_config.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, List, Optional, Union
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TokenizerConfig
|
21 |
+
|
22 |
+
|
23 |
+
@attrs.define(slots=False)
|
24 |
+
class DataShapeConfig:
|
25 |
+
latent_shape: list = []
|
26 |
+
num_video_frames: Union[None, int] = None
|
27 |
+
height: Union[None, int] = None
|
28 |
+
width: Union[None, int] = None
|
29 |
+
|
30 |
+
|
31 |
+
@attrs.define(slots=False)
|
32 |
+
class SamplingConfig:
|
33 |
+
"""
|
34 |
+
Sampling config
|
35 |
+
Args:
|
36 |
+
temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
|
37 |
+
top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
|
38 |
+
logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False.
|
39 |
+
echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
|
40 |
+
|
41 |
+
"""
|
42 |
+
|
43 |
+
temperature: float = 0.6
|
44 |
+
top_k: int = None
|
45 |
+
top_p: float = 0.9
|
46 |
+
compile_prefill: bool = False
|
47 |
+
compile_sampling: bool = True
|
48 |
+
logprobs: bool = False
|
49 |
+
echo: bool = False
|
50 |
+
|
51 |
+
|
52 |
+
@attrs.define(slots=False)
|
53 |
+
class DiffusionDecoderSamplingConfig:
|
54 |
+
"""
|
55 |
+
Diffusion decoder sampling config
|
56 |
+
Args:
|
57 |
+
guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 0.8.
|
58 |
+
sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02.
|
59 |
+
sigma (float): Initial noise level for the diffusion process. Defaults to 8.
|
60 |
+
num_steps (int): Number of denoising steps to perform. Defaults to 35.
|
61 |
+
overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2.
|
62 |
+
continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16.
|
63 |
+
continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8.
|
64 |
+
dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57.
|
65 |
+
"""
|
66 |
+
|
67 |
+
guidance: float = 1.8
|
68 |
+
sigma_min: float = 0.02
|
69 |
+
sigma: float = 8
|
70 |
+
num_steps: int = 15
|
71 |
+
overlap: int = 2
|
72 |
+
continuous_tokenizer_channel = 16
|
73 |
+
continuous_tokenizer_spatial_compression_ratio = 8
|
74 |
+
dd_train_num_video_frames: int = 57
|
75 |
+
max_iter: int = 99
|
76 |
+
fps: int = 24
|
77 |
+
|
78 |
+
|
79 |
+
@attrs.define(slots=False)
|
80 |
+
class InferenceConfig:
|
81 |
+
"""
|
82 |
+
Inference config
|
83 |
+
Args:
|
84 |
+
model_config (ModelConfig): Model config
|
85 |
+
tokenizer_config (TokenizerConfig): Tokenizer config
|
86 |
+
ckpt_path (str): Path to the checkpoint
|
87 |
+
latent_shape (list): Shape of the latent
|
88 |
+
"""
|
89 |
+
|
90 |
+
model_config: ModelConfig = None
|
91 |
+
tokenizer_config: TokenizerConfig = None
|
92 |
+
ckpt_path: str = ""
|
93 |
+
data_shape_config: DataShapeConfig = None
|
94 |
+
|
95 |
+
defaults: List[Any] = attrs.field(
|
96 |
+
factory=lambda: [
|
97 |
+
"_self_",
|
98 |
+
{"data_val": None},
|
99 |
+
{"data_shape_config": "video_shape_as_model_config"},
|
100 |
+
{"eval_job": None},
|
101 |
+
]
|
102 |
+
)
|
cosmos_predict1/autoregressive/configs/registry.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from hydra.core.config_store import ConfigStore
|
18 |
+
|
19 |
+
from cosmos_predict1.autoregressive.configs.base.callbacks import BASIC_CALLBACKS, VIDEO_TEACHER_FORCING_CALLBACK
|
20 |
+
from cosmos_predict1.autoregressive.configs.base.dataloader import get_tealrobot_video
|
21 |
+
from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearLR
|
22 |
+
from cosmos_predict1.autoregressive.configs.experiment.video2video.basic import register_experiments
|
23 |
+
from cosmos_predict1.utils import config, log
|
24 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
25 |
+
from cosmos_predict1.utils.scheduler import WarmupCosineLR
|
26 |
+
|
27 |
+
|
28 |
+
def register_checkpoint(cs):
|
29 |
+
checkpoint_local = config.CheckpointConfig(
|
30 |
+
save_iter=5000,
|
31 |
+
broadcast_via_filesystem=True,
|
32 |
+
)
|
33 |
+
cs.store(group="checkpoint", package="checkpoint", name="local", node=checkpoint_local)
|
34 |
+
|
35 |
+
|
36 |
+
def register_callbacks(cs):
|
37 |
+
cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
|
38 |
+
cs.store(
|
39 |
+
group="callbacks",
|
40 |
+
package="trainer.callbacks",
|
41 |
+
name="video_teacher_forcing",
|
42 |
+
node=VIDEO_TEACHER_FORCING_CALLBACK,
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def register_scheduler(cs):
|
47 |
+
cs.store(
|
48 |
+
group="scheduler",
|
49 |
+
package="scheduler",
|
50 |
+
name="warmup_cosine_lr",
|
51 |
+
node=L(WarmupCosineLR)(optimizer=None, warmup_iters=5000, lr_decay_iters="${trainer.max_iter}", min_lr=1e-8),
|
52 |
+
)
|
53 |
+
cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearLR)
|
54 |
+
|
55 |
+
|
56 |
+
def register_optimizer(cs):
|
57 |
+
cs.store(
|
58 |
+
group="optimizer",
|
59 |
+
package="optimizer",
|
60 |
+
name="fused_adamw",
|
61 |
+
node=L(torch.optim.AdamW)(params=None, lr=1e-3, weight_decay=0.05, fused=True),
|
62 |
+
)
|
63 |
+
cs.store(
|
64 |
+
group="optimizer",
|
65 |
+
package="optimizer",
|
66 |
+
name="sgd",
|
67 |
+
node=L(torch.optim.SGD)(params=None, lr=5e-6, momentum=0.9),
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def register_training_data(cs):
|
72 |
+
cs.store(
|
73 |
+
group="data_train",
|
74 |
+
package="dataloader_train",
|
75 |
+
name="tealrobot_video_small",
|
76 |
+
node=get_tealrobot_video(num_frames=33, video_size=[384, 640]),
|
77 |
+
)
|
78 |
+
cs.store(group="data_train", package="dataloader_train", name="tealrobot_video", node=get_tealrobot_video())
|
79 |
+
|
80 |
+
|
81 |
+
def register_configs():
|
82 |
+
log.info("Registering configs for autoregressive_base")
|
83 |
+
cs = ConfigStore.instance()
|
84 |
+
register_callbacks(cs)
|
85 |
+
register_checkpoint(cs)
|
86 |
+
register_optimizer(cs)
|
87 |
+
register_scheduler(cs)
|
88 |
+
register_training_data(cs)
|
89 |
+
register_experiments(cs)
|
cosmos_predict1/autoregressive/datasets/dataset_utils.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torchvision.transforms.functional as transforms_F
|
20 |
+
from PIL import Image
|
21 |
+
|
22 |
+
|
23 |
+
def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]:
|
24 |
+
r"""Function for obtaining the image size from the data dict.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
data_dict (dict): Input data dict
|
28 |
+
input_keys (list): List of input keys
|
29 |
+
Returns:
|
30 |
+
width (int): Width of the input image
|
31 |
+
height (int): Height of the input image
|
32 |
+
"""
|
33 |
+
|
34 |
+
data1 = data_dict[input_keys[0]]
|
35 |
+
if isinstance(data1, Image.Image):
|
36 |
+
width, height = data1.size
|
37 |
+
elif isinstance(data1, torch.Tensor):
|
38 |
+
height, width = data1.size()[-2:]
|
39 |
+
else:
|
40 |
+
raise ValueError("data to random crop should be PIL Image or tensor")
|
41 |
+
|
42 |
+
return width, height
|
43 |
+
|
44 |
+
|
45 |
+
class Augmentor:
|
46 |
+
def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
|
47 |
+
r"""Base augmentor class
|
48 |
+
|
49 |
+
Args:
|
50 |
+
input_keys (list): List of input keys
|
51 |
+
output_keys (list): List of output keys
|
52 |
+
args (dict): Arguments associated with the augmentation
|
53 |
+
"""
|
54 |
+
self.input_keys = input_keys
|
55 |
+
self.output_keys = output_keys
|
56 |
+
self.args = args
|
57 |
+
|
58 |
+
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
59 |
+
raise ValueError("Augmentor not implemented")
|
60 |
+
|
61 |
+
|
62 |
+
class ResizeSmallestSideAspectPreserving(Augmentor):
|
63 |
+
def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
|
64 |
+
super().__init__(input_keys, output_keys, args)
|
65 |
+
|
66 |
+
def __call__(self, data_dict: dict) -> dict:
|
67 |
+
r"""Performs aspect-ratio preserving resizing.
|
68 |
+
Image is resized to the dimension which has the smaller ratio of (size / target_size).
|
69 |
+
First we compute (w_img / w_target) and (h_img / h_target) and resize the image
|
70 |
+
to the dimension that has the smaller of these ratios.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
data_dict (dict): Input data dict
|
74 |
+
Returns:
|
75 |
+
data_dict (dict): Output dict where images are resized
|
76 |
+
"""
|
77 |
+
|
78 |
+
if self.output_keys is None:
|
79 |
+
self.output_keys = self.input_keys
|
80 |
+
assert self.args is not None, "Please specify args in augmentations"
|
81 |
+
|
82 |
+
img_w, img_h = self.args["img_w"], self.args["img_h"]
|
83 |
+
|
84 |
+
orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
|
85 |
+
scaling_ratio = max((img_w / orig_w), (img_h / orig_h))
|
86 |
+
target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5))
|
87 |
+
|
88 |
+
assert (
|
89 |
+
target_size[0] >= img_h and target_size[1] >= img_w
|
90 |
+
), f"Resize error. orig {(orig_w, orig_h)} desire {(img_w, img_h)} compute {target_size}"
|
91 |
+
|
92 |
+
for inp_key, out_key in zip(self.input_keys, self.output_keys):
|
93 |
+
data_dict[out_key] = transforms_F.resize(
|
94 |
+
data_dict[inp_key],
|
95 |
+
size=target_size, # type: ignore
|
96 |
+
interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC),
|
97 |
+
antialias=True,
|
98 |
+
)
|
99 |
+
|
100 |
+
if out_key != inp_key:
|
101 |
+
del data_dict[inp_key]
|
102 |
+
return data_dict
|
103 |
+
|
104 |
+
|
105 |
+
class CenterCrop(Augmentor):
|
106 |
+
def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
|
107 |
+
super().__init__(input_keys, output_keys, args)
|
108 |
+
|
109 |
+
def __call__(self, data_dict: dict) -> dict:
|
110 |
+
r"""Performs center crop.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
data_dict (dict): Input data dict
|
114 |
+
Returns:
|
115 |
+
data_dict (dict): Output dict where images are center cropped.
|
116 |
+
We also save the cropping parameters in the aug_params dict
|
117 |
+
so that it will be used by other transforms.
|
118 |
+
"""
|
119 |
+
assert (
|
120 |
+
(self.args is not None) and ("img_w" in self.args) and ("img_h" in self.args)
|
121 |
+
), "Please specify size in args"
|
122 |
+
|
123 |
+
img_w, img_h = self.args["img_w"], self.args["img_h"]
|
124 |
+
|
125 |
+
orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
|
126 |
+
for key in self.input_keys:
|
127 |
+
data_dict[key] = transforms_F.center_crop(data_dict[key], [img_h, img_w])
|
128 |
+
|
129 |
+
# We also add the aug params we use. This will be useful for other transforms
|
130 |
+
crop_x0 = (orig_w - img_w) // 2
|
131 |
+
crop_y0 = (orig_h - img_h) // 2
|
132 |
+
cropping_params = {
|
133 |
+
"resize_w": orig_w,
|
134 |
+
"resize_h": orig_h,
|
135 |
+
"crop_x0": crop_x0,
|
136 |
+
"crop_y0": crop_y0,
|
137 |
+
"crop_w": img_w,
|
138 |
+
"crop_h": img_h,
|
139 |
+
}
|
140 |
+
|
141 |
+
if "aug_params" not in data_dict:
|
142 |
+
data_dict["aug_params"] = dict()
|
143 |
+
|
144 |
+
data_dict["aug_params"]["cropping"] = cropping_params
|
145 |
+
data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"]))
|
146 |
+
return data_dict
|
147 |
+
|
148 |
+
|
149 |
+
class Normalize(Augmentor):
|
150 |
+
def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
|
151 |
+
super().__init__(input_keys, output_keys, args)
|
152 |
+
|
153 |
+
def __call__(self, data_dict: dict) -> dict:
|
154 |
+
r"""Performs data normalization.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
data_dict (dict): Input data dict
|
158 |
+
Returns:
|
159 |
+
data_dict (dict): Output dict where images are center cropped.
|
160 |
+
"""
|
161 |
+
assert self.args is not None, "Please specify args"
|
162 |
+
|
163 |
+
mean = self.args["mean"]
|
164 |
+
std = self.args["std"]
|
165 |
+
|
166 |
+
for key in self.input_keys:
|
167 |
+
if isinstance(data_dict[key], torch.Tensor):
|
168 |
+
data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255)
|
169 |
+
else:
|
170 |
+
data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor()
|
171 |
+
|
172 |
+
data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std)
|
173 |
+
return data_dict
|
cosmos_predict1/autoregressive/datasets/video_dataset.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""
|
17 |
+
Run this command to interactively debug:
|
18 |
+
PYTHONPATH=. python cosmos_predict1/autoregressive/datasets/video_dataset.py
|
19 |
+
"""
|
20 |
+
|
21 |
+
import os
|
22 |
+
import traceback
|
23 |
+
import warnings
|
24 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
from decord import VideoReader, cpu
|
29 |
+
from torch.utils.data import Dataset
|
30 |
+
from tqdm import tqdm
|
31 |
+
|
32 |
+
from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
|
33 |
+
from cosmos_predict1.autoregressive.datasets.dataset_utils import (
|
34 |
+
CenterCrop,
|
35 |
+
Normalize,
|
36 |
+
ResizeSmallestSideAspectPreserving,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
class VideoDataset(Dataset):
|
41 |
+
def __init__(self, config: VideoDatasetConfig):
|
42 |
+
"""Video Dataset class for loading video-to-video generation data."""
|
43 |
+
|
44 |
+
super().__init__()
|
45 |
+
self.dataset_dir = config.dataset_dir
|
46 |
+
self.sequence_interval = config.sequence_interval
|
47 |
+
self.sequence_length = config.num_frames
|
48 |
+
self.video_size = config.video_size
|
49 |
+
self.start_frame_interval = config.start_frame_interval
|
50 |
+
|
51 |
+
self.video_dir = self.dataset_dir
|
52 |
+
self.video_paths = [os.path.join(self.video_dir, f) for f in os.listdir(self.video_dir) if f.endswith(".mp4")]
|
53 |
+
print(f"{len(self.video_paths)} videos in total")
|
54 |
+
|
55 |
+
self.samples = self._init_samples(self.video_paths)
|
56 |
+
self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0]))
|
57 |
+
print(f"{len(self.samples)} samples in total")
|
58 |
+
self.wrong_number = 0
|
59 |
+
|
60 |
+
self.resize_transform = ResizeSmallestSideAspectPreserving(
|
61 |
+
input_keys=["video"],
|
62 |
+
args={"img_w": self.video_size[1], "img_h": self.video_size[0]},
|
63 |
+
)
|
64 |
+
self.crop_transform = CenterCrop(
|
65 |
+
input_keys=["video"],
|
66 |
+
args={"img_w": self.video_size[1], "img_h": self.video_size[0]},
|
67 |
+
)
|
68 |
+
self.normalize_transform = Normalize(
|
69 |
+
input_keys=["video"],
|
70 |
+
args={"mean": 0.5, "std": 0.5},
|
71 |
+
)
|
72 |
+
|
73 |
+
def __str__(self):
|
74 |
+
return f"{len(self.video_paths)} samples from {self.dataset_dir}"
|
75 |
+
|
76 |
+
def _init_samples(self, video_paths):
|
77 |
+
samples = []
|
78 |
+
with ThreadPoolExecutor(32) as executor:
|
79 |
+
future_to_video_path = {
|
80 |
+
executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths
|
81 |
+
}
|
82 |
+
for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)):
|
83 |
+
samples.extend(future.result())
|
84 |
+
return samples
|
85 |
+
|
86 |
+
def _load_and_process_video_path(self, video_path):
|
87 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
|
88 |
+
n_frames = len(vr)
|
89 |
+
|
90 |
+
samples = []
|
91 |
+
for frame_i in range(0, n_frames, self.start_frame_interval):
|
92 |
+
sample = dict()
|
93 |
+
sample["video_path"] = video_path
|
94 |
+
sample["orig_num_frames"] = n_frames
|
95 |
+
sample["chunk_index"] = -1
|
96 |
+
sample["frame_ids"] = []
|
97 |
+
curr_frame_i = frame_i
|
98 |
+
while True:
|
99 |
+
if curr_frame_i > (n_frames - 1):
|
100 |
+
break
|
101 |
+
sample["frame_ids"].append(curr_frame_i)
|
102 |
+
if len(sample["frame_ids"]) == self.sequence_length:
|
103 |
+
break
|
104 |
+
curr_frame_i += self.sequence_interval
|
105 |
+
# make sure there are sequence_length number of frames
|
106 |
+
if len(sample["frame_ids"]) == self.sequence_length:
|
107 |
+
sample["chunk_index"] += 1
|
108 |
+
samples.append(sample)
|
109 |
+
return samples
|
110 |
+
|
111 |
+
def __len__(self):
|
112 |
+
return len(self.samples)
|
113 |
+
|
114 |
+
def _load_video(self, video_path, frame_ids):
|
115 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
|
116 |
+
assert (np.array(frame_ids) < len(vr)).all(), "Some frame_ids are out of range."
|
117 |
+
assert (np.array(frame_ids) >= 0).all(), "Some frame_ids are negative."
|
118 |
+
vr.seek(0)
|
119 |
+
frame_data = vr.get_batch(frame_ids).asnumpy()
|
120 |
+
fps = vr.get_avg_fps()
|
121 |
+
return frame_data, fps
|
122 |
+
|
123 |
+
def _get_frames(self, video_path, frame_ids):
|
124 |
+
frames, fps = self._load_video(video_path, frame_ids)
|
125 |
+
frames = frames.astype(np.uint8)
|
126 |
+
frames = torch.from_numpy(frames)
|
127 |
+
frames = frames.permute(0, 3, 1, 2) # Rearrange from [T, H, W, C] to [T, C, H, W]
|
128 |
+
return frames, fps
|
129 |
+
|
130 |
+
def __getitem__(self, index):
|
131 |
+
try:
|
132 |
+
sample = self.samples[index]
|
133 |
+
video_path = sample["video_path"]
|
134 |
+
frame_ids = sample["frame_ids"]
|
135 |
+
|
136 |
+
data = dict()
|
137 |
+
|
138 |
+
video, fps = self._get_frames(video_path, frame_ids)
|
139 |
+
data["video"] = video
|
140 |
+
data["fps"] = fps
|
141 |
+
data["num_frames"] = self.sequence_length
|
142 |
+
data["orig_num_frames"] = sample["orig_num_frames"]
|
143 |
+
data["chunk_index"] = sample["chunk_index"]
|
144 |
+
data["frame_start"] = frame_ids[0]
|
145 |
+
data["frame_end"] = frame_ids[-1]
|
146 |
+
|
147 |
+
data["video_name"] = {
|
148 |
+
"video_path": video_path,
|
149 |
+
"start_frame_id": str(frame_ids[0]),
|
150 |
+
}
|
151 |
+
|
152 |
+
# resize video to smallest side aspect preserving
|
153 |
+
data = self.resize_transform(data)
|
154 |
+
# center crop video
|
155 |
+
data = self.crop_transform(data)
|
156 |
+
# normalize video
|
157 |
+
data = self.normalize_transform(data)
|
158 |
+
|
159 |
+
data["video"] = data["video"].permute(1, 0, 2, 3) # Rearrange from [T, C, H, W] to [C, T, H, W]
|
160 |
+
|
161 |
+
return data
|
162 |
+
except Exception:
|
163 |
+
warnings.warn(
|
164 |
+
f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped "
|
165 |
+
f"(by randomly sampling another sample in the same dataset)."
|
166 |
+
)
|
167 |
+
warnings.warn("FULL TRACEBACK:")
|
168 |
+
warnings.warn(traceback.format_exc())
|
169 |
+
self.wrong_number += 1
|
170 |
+
print(self.wrong_number)
|
171 |
+
return self[np.random.randint(len(self.samples))]
|
172 |
+
|
173 |
+
|
174 |
+
if __name__ == "__main__":
|
175 |
+
config = VideoDatasetConfig(dataset_dir="datasets/cosmos_nemo_assets/videos/")
|
176 |
+
dataset = VideoDataset(config)
|
177 |
+
|
178 |
+
indices = [0, 1, 2, -1]
|
179 |
+
for idx in indices:
|
180 |
+
data = dataset[idx]
|
181 |
+
print(
|
182 |
+
(
|
183 |
+
f"{idx=} "
|
184 |
+
f"{data['video'].sum()=}\n"
|
185 |
+
f"{data['video'].shape=}\n"
|
186 |
+
f"{data['video_name']=}\n"
|
187 |
+
f"{data.keys()=}\n"
|
188 |
+
"---"
|
189 |
+
)
|
190 |
+
)
|
cosmos_predict1/autoregressive/diffusion_decoder/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/diffusion_decoder/config/base/conditioner.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Dict, Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from cosmos_predict1.diffusion.conditioner import BaseVideoCondition, GeneralConditioner
|
22 |
+
from cosmos_predict1.diffusion.config.base.conditioner import (
|
23 |
+
FPSConfig,
|
24 |
+
ImageSizeConfig,
|
25 |
+
LatentConditionConfig,
|
26 |
+
LatentConditionSigmaConfig,
|
27 |
+
NumFramesConfig,
|
28 |
+
PaddingMaskConfig,
|
29 |
+
TextConfig,
|
30 |
+
)
|
31 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
32 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
|
37 |
+
# latent_condition will concat to the input of network, along channel dim;
|
38 |
+
# cfg will make latent_condition all zero padding.
|
39 |
+
latent_condition: Optional[torch.Tensor] = None
|
40 |
+
latent_condition_sigma: Optional[torch.Tensor] = None
|
41 |
+
|
42 |
+
|
43 |
+
class VideoDiffusionDecoderConditioner(GeneralConditioner):
|
44 |
+
def forward(
|
45 |
+
self,
|
46 |
+
batch: Dict,
|
47 |
+
override_dropout_rate: Optional[Dict[str, float]] = None,
|
48 |
+
) -> VideoLatentDiffusionDecoderCondition:
|
49 |
+
output = super()._forward(batch, override_dropout_rate)
|
50 |
+
return VideoLatentDiffusionDecoderCondition(**output)
|
51 |
+
|
52 |
+
|
53 |
+
VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)(
|
54 |
+
text=TextConfig(),
|
55 |
+
fps=FPSConfig(),
|
56 |
+
num_frames=NumFramesConfig(),
|
57 |
+
image_size=ImageSizeConfig(),
|
58 |
+
padding_mask=PaddingMaskConfig(),
|
59 |
+
latent_condition=LatentConditionConfig(),
|
60 |
+
latent_condition_sigma=LatentConditionSigmaConfig(),
|
61 |
+
)
|
cosmos_predict1/autoregressive/diffusion_decoder/config/config_latent_diffusion_decoder.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, List
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.autoregressive.diffusion_decoder.config.registry import register_configs as register_dd_configs
|
21 |
+
from cosmos_predict1.diffusion.config.base.model import LatentDiffusionDecoderModelConfig
|
22 |
+
from cosmos_predict1.diffusion.config.registry import register_configs
|
23 |
+
from cosmos_predict1.utils import config
|
24 |
+
from cosmos_predict1.utils.config_helper import import_all_modules_from_package
|
25 |
+
|
26 |
+
|
27 |
+
@attrs.define(slots=False)
|
28 |
+
class Config(config.Config):
|
29 |
+
# default config groups that will be used unless overwritten
|
30 |
+
# see config groups in registry.py
|
31 |
+
defaults: List[Any] = attrs.field(
|
32 |
+
factory=lambda: [
|
33 |
+
"_self_",
|
34 |
+
{"net": None},
|
35 |
+
{"conditioner": "basic"},
|
36 |
+
{"tokenizer": "tokenizer"},
|
37 |
+
{"tokenizer_corruptor": None},
|
38 |
+
{"latent_corruptor": None},
|
39 |
+
{"pixel_corruptor": None},
|
40 |
+
{"experiment": None},
|
41 |
+
]
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def make_config():
|
46 |
+
c = Config(model=LatentDiffusionDecoderModelConfig())
|
47 |
+
|
48 |
+
# Specifying values through instances of attrs
|
49 |
+
c.job.project = "cosmos_video4"
|
50 |
+
c.job.group = "debug"
|
51 |
+
c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}"
|
52 |
+
|
53 |
+
# Call this function to register config groups for advanced overriding.
|
54 |
+
register_configs()
|
55 |
+
register_dd_configs()
|
56 |
+
|
57 |
+
# experiment config are defined in the experiment folder
|
58 |
+
# call import_all_modules_from_package to register them
|
59 |
+
import_all_modules_from_package("cosmos_predict1.diffusion.config.inference", reload=True)
|
60 |
+
import_all_modules_from_package("cosmos_predict1.autoregressive.diffusion_decoder.config.inference", reload=True)
|
61 |
+
return c
|
cosmos_predict1/autoregressive/diffusion_decoder/config/inference/cosmos_diffusiondecoder_7b.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from hydra.core.config_store import ConfigStore
|
17 |
+
|
18 |
+
from cosmos_predict1.autoregressive.diffusion_decoder.network import DiffusionDecoderGeneralDIT
|
19 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
20 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
21 |
+
|
22 |
+
num_frames = 57
|
23 |
+
Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict(
|
24 |
+
dict(
|
25 |
+
defaults=[
|
26 |
+
{"override /net": "faditv2_7b"},
|
27 |
+
{"override /tokenizer": "cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624"},
|
28 |
+
{"override /conditioner": "video_latent_diffusion_decoder_cond"},
|
29 |
+
{"override /tokenizer_corruptor": "cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224"},
|
30 |
+
"_self_",
|
31 |
+
],
|
32 |
+
job=dict(
|
33 |
+
group="diffusion_deocder_FT_7Bv1_001",
|
34 |
+
name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token",
|
35 |
+
),
|
36 |
+
model=dict(
|
37 |
+
diffusion_decoder_cond_sigma_low=0.0,
|
38 |
+
diffusion_decoder_cond_sigma_high=0.0,
|
39 |
+
diffusion_decoder_corrupt_prob=0.0,
|
40 |
+
condition_on_tokenizer_corruptor_token=True,
|
41 |
+
latent_shape=[
|
42 |
+
16,
|
43 |
+
num_frames,
|
44 |
+
88,
|
45 |
+
160,
|
46 |
+
],
|
47 |
+
tokenizer_corruptor=dict(
|
48 |
+
pixel_chunk_duration=num_frames,
|
49 |
+
latent_chunk_duration=1 + (num_frames - 1) // 8,
|
50 |
+
),
|
51 |
+
net=L(DiffusionDecoderGeneralDIT)(
|
52 |
+
diffusion_decoder_condition_on_sigma=False,
|
53 |
+
max_img_h=240,
|
54 |
+
max_img_w=240,
|
55 |
+
rope_h_extrapolation_ratio=1.5,
|
56 |
+
rope_w_extrapolation_ratio=1.5,
|
57 |
+
rope_t_extrapolation_ratio=1,
|
58 |
+
block_x_format="THWBD",
|
59 |
+
is_diffusion_decoder=True,
|
60 |
+
patch_spatial=2,
|
61 |
+
diffusion_decoder_condition_on_token=True,
|
62 |
+
diffusion_decoder_token_condition_voc_size=64000,
|
63 |
+
diffusion_decoder_token_condition_dim=32,
|
64 |
+
),
|
65 |
+
tokenizer=dict(
|
66 |
+
video_vae=dict(
|
67 |
+
pixel_chunk_duration=num_frames,
|
68 |
+
)
|
69 |
+
),
|
70 |
+
conditioner=dict(
|
71 |
+
latent_condition=dict(
|
72 |
+
dropout_rate=0.2,
|
73 |
+
)
|
74 |
+
),
|
75 |
+
),
|
76 |
+
)
|
77 |
+
)
|
78 |
+
|
79 |
+
cs = ConfigStore.instance()
|
80 |
+
cs.store(
|
81 |
+
group="experiment",
|
82 |
+
package="_global_",
|
83 |
+
name=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY["job"]["name"],
|
84 |
+
node=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY,
|
85 |
+
)
|
cosmos_predict1/autoregressive/diffusion_decoder/config/registry.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from hydra.core.config_store import ConfigStore
|
17 |
+
|
18 |
+
from cosmos_predict1.autoregressive.diffusion_decoder.config.base.conditioner import (
|
19 |
+
VideoLatentDiffusionDecoderConditionerConfig,
|
20 |
+
)
|
21 |
+
from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQJITTokenizer
|
22 |
+
from cosmos_predict1.diffusion.module.pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
|
23 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
24 |
+
|
25 |
+
|
26 |
+
def get_cosmos_video_discrete_tokenizer_comp8x16x16(
|
27 |
+
resolution: str,
|
28 |
+
chunk_duration: int,
|
29 |
+
checkpoint_path: str,
|
30 |
+
):
|
31 |
+
assert resolution in ["720"]
|
32 |
+
|
33 |
+
pixel_chunk_duration = chunk_duration
|
34 |
+
temporal_compression_factor = 8
|
35 |
+
spatial_compression_factor = 16
|
36 |
+
|
37 |
+
return L(DiscreteVideoFSQJITTokenizer)(
|
38 |
+
enc_fp=checkpoint_path.replace(".jit", "encoder.jit"),
|
39 |
+
dec_fp=checkpoint_path.replace(".jit", "decoder.jit"),
|
40 |
+
name="discrete_video_fsq",
|
41 |
+
latent_ch=6,
|
42 |
+
is_bf16=True,
|
43 |
+
pixel_chunk_duration=pixel_chunk_duration,
|
44 |
+
latent_chunk_duration=1 + (pixel_chunk_duration - 1) // temporal_compression_factor,
|
45 |
+
max_enc_batch_size=8,
|
46 |
+
max_dec_batch_size=4,
|
47 |
+
levels=[8, 8, 8, 5, 5, 5],
|
48 |
+
compression_ratio=[temporal_compression_factor, spatial_compression_factor, spatial_compression_factor],
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
def get_cosmos_video_tokenizer_comp8x8x8(resolution: str, chunk_duration: int, checkpoint_path=None):
|
53 |
+
pixel_chunk_duration = chunk_duration
|
54 |
+
temporal_compression_factor = 8
|
55 |
+
spatial_compression_factor = 8
|
56 |
+
|
57 |
+
return L(JointImageVideoSharedJITTokenizer)(
|
58 |
+
video_vae=L(VideoJITTokenizer)(
|
59 |
+
name="cosmos_predict1_tokenizer",
|
60 |
+
latent_ch=16,
|
61 |
+
is_bf16=True,
|
62 |
+
pixel_chunk_duration=pixel_chunk_duration,
|
63 |
+
temporal_compression_factor=temporal_compression_factor,
|
64 |
+
spatial_compression_factor=spatial_compression_factor,
|
65 |
+
spatial_resolution=resolution,
|
66 |
+
),
|
67 |
+
image_vae=L(JITVAE)(
|
68 |
+
name="cosmos_predict1_tokenizer",
|
69 |
+
latent_ch=16,
|
70 |
+
is_image=False,
|
71 |
+
is_bf16=True,
|
72 |
+
),
|
73 |
+
name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624",
|
74 |
+
latent_ch=16,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def register_tokenizer(cs):
|
79 |
+
cs.store(
|
80 |
+
group="tokenizer",
|
81 |
+
package="model.tokenizer",
|
82 |
+
name="cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624",
|
83 |
+
node=get_cosmos_video_tokenizer_comp8x8x8(
|
84 |
+
resolution="720",
|
85 |
+
chunk_duration=121,
|
86 |
+
checkpoint_path="checkpoints/Cosmos-Tokenize1-CV8x8x8-720p/.jit",
|
87 |
+
),
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
def register_corruptor(cs):
|
92 |
+
cs.store(
|
93 |
+
group="tokenizer_corruptor",
|
94 |
+
package="model.tokenizer_corruptor",
|
95 |
+
name="cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224",
|
96 |
+
node=get_cosmos_video_discrete_tokenizer_comp8x16x16(
|
97 |
+
resolution="720",
|
98 |
+
chunk_duration=49,
|
99 |
+
checkpoint_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/.jit",
|
100 |
+
),
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
+
def register_conditioner(cs):
|
105 |
+
cs.store(
|
106 |
+
group="conditioner",
|
107 |
+
package="model.conditioner",
|
108 |
+
name="video_latent_diffusion_decoder_cond",
|
109 |
+
node=VideoLatentDiffusionDecoderConditionerConfig,
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def register_configs():
|
114 |
+
cs = ConfigStore.instance()
|
115 |
+
|
116 |
+
register_conditioner(cs)
|
117 |
+
register_corruptor(cs)
|
118 |
+
register_tokenizer(cs)
|