Commit · 226c7c9
1 Parent(s): e22a639
add cosmos-transfer1/ into repo
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +203 -0
- app.py +11 -2
- cosmos_transfer1/auxiliary/depth_anything/inference/__init__.py +0 -0
- cosmos_transfer1/auxiliary/depth_anything/inference/depth_anything_pipeline.py +55 -0
- cosmos_transfer1/auxiliary/depth_anything/model/__init__.py +0 -0
- cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py +151 -0
- cosmos_transfer1/auxiliary/guardrail/README.md +17 -0
- cosmos_transfer1/auxiliary/guardrail/__init__.py +14 -0
- cosmos_transfer1/auxiliary/guardrail/aegis/__init__.py +14 -0
- cosmos_transfer1/auxiliary/guardrail/aegis/aegis.py +135 -0
- cosmos_transfer1/auxiliary/guardrail/aegis/categories.py +192 -0
- cosmos_transfer1/auxiliary/guardrail/blocklist/__init__.py +14 -0
- cosmos_transfer1/auxiliary/guardrail/blocklist/blocklist.py +216 -0
- cosmos_transfer1/auxiliary/guardrail/blocklist/utils.py +45 -0
- cosmos_transfer1/auxiliary/guardrail/common/__init__.py +0 -0
- cosmos_transfer1/auxiliary/guardrail/common/core.py +71 -0
- cosmos_transfer1/auxiliary/guardrail/common/io_utils.py +78 -0
- cosmos_transfer1/auxiliary/guardrail/common/presets.py +75 -0
- cosmos_transfer1/auxiliary/guardrail/face_blur_filter/__init__.py +14 -0
- cosmos_transfer1/auxiliary/guardrail/face_blur_filter/blur_utils.py +35 -0
- cosmos_transfer1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py +225 -0
- cosmos_transfer1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py +117 -0
- cosmos_transfer1/auxiliary/guardrail/llamaGuard3/__init__.py +14 -0
- cosmos_transfer1/auxiliary/guardrail/llamaGuard3/categories.py +31 -0
- cosmos_transfer1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py +122 -0
- cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/__init__.py +14 -0
- cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py +60 -0
- cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py +185 -0
- cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py +46 -0
- cosmos_transfer1/auxiliary/human_keypoint/human_keypoint.py +155 -0
- cosmos_transfer1/auxiliary/robot_augmentation/README.md +112 -0
- cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py +577 -0
- cosmos_transfer1/auxiliary/sam2/sam2_model.py +392 -0
- cosmos_transfer1/auxiliary/sam2/sam2_pipeline.py +126 -0
- cosmos_transfer1/auxiliary/sam2/sam2_utils.py +168 -0
- cosmos_transfer1/auxiliary/tokenizer/inference/__init__.py +14 -0
- cosmos_transfer1/auxiliary/tokenizer/inference/image_cli.py +188 -0
- cosmos_transfer1/auxiliary/tokenizer/inference/image_lib.py +124 -0
- cosmos_transfer1/auxiliary/tokenizer/inference/utils.py +402 -0
- cosmos_transfer1/auxiliary/tokenizer/inference/video_cli.py +210 -0
- cosmos_transfer1/auxiliary/tokenizer/inference/video_lib.py +146 -0
- cosmos_transfer1/auxiliary/tokenizer/modules/__init__.py +61 -0
- cosmos_transfer1/auxiliary/tokenizer/modules/distributions.py +42 -0
- cosmos_transfer1/auxiliary/tokenizer/modules/layers2d.py +329 -0
- cosmos_transfer1/auxiliary/tokenizer/modules/layers3d.py +969 -0
- cosmos_transfer1/auxiliary/tokenizer/modules/patching.py +311 -0
- cosmos_transfer1/auxiliary/tokenizer/modules/quantizers.py +513 -0
- cosmos_transfer1/auxiliary/tokenizer/modules/utils.py +116 -0
- cosmos_transfer1/auxiliary/tokenizer/networks/__init__.py +39 -0
- cosmos_transfer1/auxiliary/tokenizer/networks/configs.py +147 -0
.gitignore
ADDED
@@ -0,0 +1,203 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
app.py
CHANGED
@@ -1,4 +1,6 @@
 import os
+import sys
+import time
 from typing import List, Tuple
 
 import gradio as gr
@@ -33,14 +35,16 @@ download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
 from test_environment import main as check_environment
 from test_environment import setup_environment
 
-setup_environment()
 
 # setup env
 os.environ["CUDA_HOME"] = "/usr/local/cuda"
 os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
 os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
 
-check_environment()
+if not check_environment():
+    setup_environment()
+    if not check_environment():
+        sys.exit(1)
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Workaround to suppress MP warning
 
@@ -279,6 +283,9 @@ def generate_video(
     else:
         actual_seed = seed
 
+    log.info(f"actual_seed: {actual_seed}")
+
+    start_time = time.time()
     args, control_inputs = parse_arguments(
         controlnet_specs_in={
             "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
@@ -294,6 +301,8 @@ def generate_video(
         seed=seed,
     )
     videos, prompts = inference(args, control_inputs)
+    end_time = time.time()
+    log.info(f"Time taken: {end_time - start_time} s")
 
     video = videos[0]
     return video, video, actual_seed
cosmos_transfer1/auxiliary/depth_anything/inference/__init__.py
ADDED
File without changes
cosmos_transfer1/auxiliary/depth_anything/inference/depth_anything_pipeline.py
ADDED
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+from PIL import Image
+
+from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Depth Estimation using Depth Anything V2")
+    parser.add_argument("--input", type=str, required=True, help="Path to input image or video file")
+    parser.add_argument("--output", type=str, required=True, help="Path to save the output image or video")
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["image", "video"],
+        default="image",
+        help="Processing mode: 'image' for a single image, 'video' for a video file",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    model = DepthAnythingModel()
+
+    if args.mode == "image":
+        # Load the input image and predict its depth
+        image = Image.open(args.input).convert("RGB")
+        depth_image = model.predict_depth(image)
+        depth_image.save(args.output)
+        print(f"Depth image saved to {args.output}")
+    elif args.mode == "video":
+        # Process the video and save the output
+        out_path = model.predict_depth_video(args.input, args.output)
+        if out_path:
+            print(f"Depth video saved to {out_path}")
+
+
+if __name__ == "__main__":
+    main()
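For orientation, the file above is a standalone CLI; a plausible invocation, assuming the module path mirrors the file path and using placeholder file names, is:

    python -m cosmos_transfer1.auxiliary.depth_anything.inference.depth_anything_pipeline \
        --input assets/example.mp4 --output outputs/example_depth.mp4 --mode video

Note that video mode calls model.predict_depth_video, while the DepthAnythingModel class added in the next file exposes its video path through __call__, so that method name must resolve for the video branch to work.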
cosmos_transfer1/auxiliary/depth_anything/model/__init__.py
ADDED
File without changes
cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py
ADDED
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import cv2
+import imageio
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModelForDepthEstimation
+
+from cosmos_transfer1.checkpoints import DEPTH_ANYTHING_MODEL_CHECKPOINT
+from cosmos_transfer1.utils import log
+
+
+class DepthAnythingModel:
+    def __init__(self):
+        """
+        Initialize the Depth Anything model and its image processor.
+        """
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # Load image processor and model with half precision
+        print(f"Loading Depth Anything model - {DEPTH_ANYTHING_MODEL_CHECKPOINT}...")
+        self.image_processor = AutoImageProcessor.from_pretrained(
+            DEPTH_ANYTHING_MODEL_CHECKPOINT,
+            torch_dtype=torch.float16,
+            trust_remote_code=True,
+        )
+        self.model = AutoModelForDepthEstimation.from_pretrained(
+            DEPTH_ANYTHING_MODEL_CHECKPOINT,
+            torch_dtype=torch.float16,
+            trust_remote_code=True,
+        ).to(self.device)
+
+    def predict_depth(self, image: Image.Image) -> Image.Image:
+        """
+        Process a single PIL image and return a depth map as a uint16 PIL Image.
+        """
+        # Prepare inputs for the model
+        inputs = self.image_processor(images=image, return_tensors="pt")
+        # Move all tensors to the proper device with half precision
+        inputs = {k: v.to(self.device, dtype=torch.float16) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+            predicted_depth = outputs.predicted_depth
+
+        # Interpolate the predicted depth to the original image size
+        prediction = torch.nn.functional.interpolate(
+            predicted_depth.unsqueeze(1),
+            size=image.size[::-1],  # PIL image size is (width, height), interpolate expects (height, width)
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        # Convert the output tensor to a numpy array and save as a depth image
+        output = prediction.squeeze().cpu().numpy()
+        depth_image = DepthAnythingModel.save_depth(output)
+        return depth_image
+
+    def __call__(self, input_video: str, output_video: str = "depth.mp4") -> str:
+        """
+        Process a video file frame-by-frame to produce a depth-estimated video.
+        The output video is saved as an MP4 file.
+        """
+
+        log.info(f"Processing video: {input_video} to generate depth video: {output_video}")
+        assert os.path.exists(input_video)
+
+        cap = cv2.VideoCapture(input_video)
+        if not cap.isOpened():
+            print("Error: Cannot open video file.")
+            return
+
+        # Retrieve video properties
+        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        fps = cap.get(cv2.CAP_PROP_FPS)
+
+        depths = []
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            # Convert frame from BGR to RGB and then to PIL Image
+            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            inputs = self.image_processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device, dtype=torch.float16) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                predicted_depth = outputs.predicted_depth
+
+            # For video processing, take the first output and interpolate to original size
+            prediction = torch.nn.functional.interpolate(
+                predicted_depth[0].unsqueeze(0).unsqueeze(0),
+                size=(frame_height, frame_width),
+                mode="bicubic",
+                align_corners=False,
+            )
+            depth = prediction.squeeze().cpu().numpy()
+            depths += [depth]
+        cap.release()
+
+        depths = np.stack(depths)
+        depths_normed = (depths - depths.min()) / (depths.max() - depths.min() + 1e-8) * 255.0
+        depths_normed = depths_normed.astype(np.uint8)
+
+        os.makedirs(os.path.dirname(output_video), exist_ok=True)
+        self.write_video(depths_normed, output_video, fps=fps)
+        return output_video
+
+    @staticmethod
+    def save_depth(output: np.ndarray) -> Image.Image:
+        """
+        Convert the raw depth output (float values) into a uint16 PIL Image.
+        """
+        depth_min = output.min()
+        depth_max = output.max()
+        max_val = (2**16) - 1  # Maximum value for uint16
+
+        if depth_max - depth_min > np.finfo("float").eps:
+            out_array = max_val * (output - depth_min) / (depth_max - depth_min)
+        else:
+            out_array = np.zeros_like(output)
+
+        formatted = out_array.astype("uint16")
+        depth_image = Image.fromarray(formatted, mode="I;16")
+        return depth_image
+
+    @staticmethod
+    def write_video(frames, output_path, fps=30):
+        with imageio.get_writer(output_path, fps=fps, macro_block_size=8) as writer:
+            for frame in frames:
+                if len(frame.shape) == 2:  # single channel
+                    frame = frame[:, :, None].repeat(3, axis=2)
+                writer.append_data(frame)
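As a quick sketch of how the class above is meant to be used (the file names here are placeholders, not part of the commit):

    from PIL import Image
    from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel

    model = DepthAnythingModel()  # loads the fp16 checkpoint, on CUDA when available
    depth = model.predict_depth(Image.open("frame.png").convert("RGB"))
    depth.save("frame_depth.png")  # 16-bit "I;16" PIL image
    model("input.mp4", "outputs/depth.mp4")  # __call__ runs the frame-by-frame video path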
cosmos_transfer1/auxiliary/guardrail/README.md
ADDED
@@ -0,0 +1,17 @@
+# Cosmos Guardrail
+
+This page outlines a set of tools to ensure content safety in Cosmos. For implementation details, please consult the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai).
+
+## Overview
+
+Our guardrail system consists of two stages: pre-Guard and post-Guard.
+
+Cosmos pre-Guard models are applied to text input, including input prompts and upsampled prompts.
+
+* Blocklist: a keyword list checker for detecting harmful keywords
+* Llama Guard 3: an LLM-based approach for blocking harmful prompts
+
+Cosmos post-Guard models are applied to video frames generated by Cosmos models.
+
+* Video Content Safety Filter: a classifier trained to distinguish between safe and unsafe video frames
+* Face Blur Filter: a face detection and blurring module
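The two-stage design described above is wired together through the GuardrailRunner API added elsewhere in this commit; a minimal pre-Guard sketch (the checkpoint path is a placeholder) might look like:

    from cosmos_transfer1.auxiliary.guardrail.blocklist.blocklist import Blocklist
    from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner

    # Chain text safety models; further checkers can be appended to the list.
    runner = GuardrailRunner(safety_models=[Blocklist(checkpoint_dir="checkpoints")])
    is_safe, message = runner.run_safety_check("a robot arm stacking crates")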
cosmos_transfer1/auxiliary/guardrail/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
cosmos_transfer1/auxiliary/guardrail/aegis/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
cosmos_transfer1/auxiliary/guardrail/aegis/aegis.py
ADDED
@@ -0,0 +1,135 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import torch
+from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from cosmos_transfer1.auxiliary.guardrail.aegis.categories import UNSAFE_CATEGORIES
+from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
+from cosmos_transfer1.checkpoints import AEGIS_MODEL_CHECKPOINT, LLAMA_GUARD_MODEL_CHECKPOINT
+from cosmos_transfer1.utils import log, misc
+
+SAFE = misc.Color.green("SAFE")
+UNSAFE = misc.Color.red("UNSAFE")
+
+
+class Aegis(ContentSafetyGuardrail):
+    def __init__(
+        self,
+        checkpoint_dir: str,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    ) -> None:
+        self.checkpoint_dir = checkpoint_dir
+        self.device = device
+        self.dtype = torch.bfloat16
+        base_model_id = "meta-llama/LlamaGuard-7b"
+        aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
+
+        base_model_dir = os.path.join(self.checkpoint_dir, base_model_id)
+        aegis_adapter_dir = os.path.join(self.checkpoint_dir, aegis_adapter)
+
+        base_model = AutoModelForCausalLM.from_pretrained(base_model_id, cache_dir=base_model_dir)
+        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=base_model_dir)
+        self.model = PeftModel.from_pretrained(base_model, aegis_adapter, cache_dir=aegis_adapter_dir)
+
+        self.model.to(self.device, dtype=self.dtype).eval()
+
+    def get_moderation_prompt(self, user_prompt: str) -> str:
+        """Create the moderation prompt for the Aegis model."""
+        full_prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{"".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+"""
+        full_prompt += f"User: {user_prompt}\n\n"
+        full_prompt += """<END CONVERSATION>
+
+Provide your safety assessment for User in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
+
+        return full_prompt
+
+    def get_aegis_block_message(self, moderation_output: str) -> str:
+        """Extract the blocked category and reason from the Aegis model output."""
+        block_msg = "Prompt blocked by Aegis."
+        try:
+            # Extract blocked category and reason
+            if len(moderation_output.split("\n")) >= 2:
+                moderation_output_list = moderation_output.split("\n")
+                blocked_category = int(moderation_output_list[1].lstrip("O")) - 1
+                if 0 <= blocked_category < len(UNSAFE_CATEGORIES):
+                    blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"]
+                    blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"]
+                    block_msg = f"{blocked_category_name}: {blocked_reason}"
+        except Exception as e:
+            log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}")
+        return block_msg
+
+    def filter_aegis_output(self, prompt: str) -> tuple[bool, str]:
+        """Filter the Aegis model output and return the safety status and message."""
+        full_prompt = self.get_moderation_prompt(prompt)
+        inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device)
+        output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id)
+        prompt_len = inputs["input_ids"].shape[-1]
+        moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
+
+        if "unsafe" in moderation_output.lower():
+            block_msg = self.get_aegis_block_message(moderation_output)
+            return False, block_msg
+        else:
+            return True, ""
+
+    def is_safe(self, prompt: str) -> tuple[bool, str]:
+        """Check if the input prompt is safe according to the Aegis model."""
+        try:
+            return self.filter_aegis_output(prompt)
+        except Exception as e:
+            log.error(f"Unexpected error occurred when running Aegis guardrail: {e}")
+            return True, "Unexpected error occurred when running Aegis guardrail."
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=str,
+        help="Path to the Aegis checkpoint folder",
+        default=DEFAULT_CHECKPOINT_DIR,
+    )
+    return parser.parse_args()
+
+
+def main(args):
+    aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
+    runner = GuardrailRunner(safety_models=[aegis])
+    with misc.timer("aegis safety check"):
+        safety, message = runner.run_safety_check(args.prompt)
+    log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
+    log.info(f"Message: {message}") if not safety else None
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
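The Aegis checker can also be driven programmatically, mirroring main() above (the checkpoint path is a placeholder). Note that parse_args references DEFAULT_CHECKPOINT_DIR, which is neither defined nor imported in this file, so the __main__ entry point would fail as committed:

    from cosmos_transfer1.auxiliary.guardrail.aegis.aegis import Aegis
    from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner

    runner = GuardrailRunner(safety_models=[Aegis(checkpoint_dir="checkpoints")])
    is_safe, message = runner.run_safety_check("a forklift moving pallets")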
cosmos_transfer1/auxiliary/guardrail/aegis/categories.py
ADDED
@@ -0,0 +1,192 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+VIOLENCE = {
+    "category": "violence",
+    "blocked_reason": """This prompt is blocked due to the presence of violent content.""",
+    "instruction": """O1: Violence.
+Should not
+- Help people plan or engage in violence.
+- Encourage people to engage in violence.
+- Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader.
+Can
+- Provide information on violence.
+- Discuss the topics of violence, and can discuss historical events involving violence.
+""",
+}
+
+SEXUAL = {
+    "category": "sexual",
+    "blocked_reason": """This prompt is blocked due to the presence of sexual content.""",
+    "instruction": """O2: Sexual.
+Should not
+- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+- Engage or encourage sexual activity that is considered illegal.
+Can
+- Discuss topics related to sexuality and sexual education.
+- Provide general information regarding sexual activities.
+- Provide guidance on topics related to sexual health.
+""",
+}
+
+CRIMINAL_PLANNING = {
+    "category": "criminal_planning",
+    "blocked_reason": """This prompt is blocked due to the presence of criminal planning/confessions content.""",
+    "instruction": """O3: Criminal Planning/Confessions.
+Should not
+- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+- Encourage people to engage in criminal activities or condone criminal behavior.
+- Have a confession of crime that was committed or will be committed in future.
+Can
+- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
+""",
+}
+
+WEAPONS = {
+    "category": "weapons",
+    "blocked_reason": """This prompt is blocked due to the presence of weapons content.""",
+    "instruction": """O4: Guns and Illegal Weapons.
+Should not
+- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+- Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them.
+Can
+- Discuss firearms and the arguments for and against firearm ownership.
+""",
+}
+
+SUBSTANCE_ABUSE = {
+    "category": "substance_abuse",
+    "blocked_reason": """This prompt is blocked due to the presence of substance abuse content.""",
+    "instruction": """O5: Controlled/Regulated Substances.
+Should not
+- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+- Assist or encourage people to create such substances when it is illegal to do so.
+Can
+- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
+""",
+}
+
+SUICIDE = {
+    "category": "suicide",
+    "blocked_reason": "This prompt is blocked due to the presence of suicide and self harm content.",
+    "instruction": """O6: Suicide and Self Harm.
+Should not
+- Encourage people to harm themselves.
+- Romanticize or condone self-harm or suicide.
+- Provide information regarding the methods of suicide or self-harm.
+- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+- Contain confessions or intent of self harm
+Can
+- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+Should
+- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
+""",
+}
+
+CSAM = {
+    "category": "csam",
+    "blocked_reason": "This prompt contains material related to Child Sexual Abuse",
+    "instruction": """O7: Sexual (minor).
+Should not
+- Assist, engage or encourage intention to engage in sexual activity with an underage/minor.
+- Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor.
+- Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media.
+- Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children.
+Can
+- Provide information about sex education for use in schools.
+""",
+}
+
+HATRED = {
+    "category": "hatred",
+    "blocked_reason": "This prompt is blocked due to the presence of hate speech content.",
+    "instruction": """O8: Hate/Identity Hate.
+Should not
+- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their
+race, color, ethnicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste.
+- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+- Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country.
+- Should not contain slurs.
+- Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear.
+Can
+- Provide information on discrimination.
+- Discuss the topics of discrimination, and can discuss historical events involving discrimination.
+- Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate).
+""",
+}
+
+HARASSMENT = {
+    "category": "harassment",
+    "blocked_reason": "This prompt is blocked due to the presence of harassment content.",
+    "instruction": """O10: Harassment.
+Should not
+- Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur.
+- Engage in harassment, misconduct and unlawful discrimination
+- Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex,
+Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled),
+Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation,
+- Taking or requesting statutorily protected leave Body characteristics.
+- Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct.
+- Bullying.
+- Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law.
+- Making rude or insulting comments, demeaning, and objectifying terms toward an individual.
+Can
+- Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting.
+- Provide information on protected characteristics or refers, cites material that describes what protected characteristics are.
+- Provide information or refer or cite material that informs what harassment is.
+""",
+}
+
+THREAT = {
+    "category": "threat",
+    "blocked_reason": "This prompt is blocked due to the presence of threat content.",
+    "instruction": """O11: Threat.
+Should not
+- Discuss a communication of intent to inflict harm or loss on another person.
+- Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control.
+- Should not contain intention of knowingly putting another person in fear of bodily injury.
+- Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person.
+- Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will.
+Can
+- Provide information about threat is and the meaning of threat.
+""",
+}
+
+PROFANITY = {
+    "category": "profanity",
+    "blocked_reason": "This prompt is blocked due to the presence of profanity content.",
+    "instruction": """O12: Profanity.
+Should not
+- Contain swear words, curse words, or other obscene or profane language
+- This includes offensive words used without any intention to act on them or without any explicit targeted recipient
+""",
+}
+
+UNSAFE_CATEGORIES = [
+    VIOLENCE,
+    SEXUAL,
+    CRIMINAL_PLANNING,
+    WEAPONS,
+    SUBSTANCE_ABUSE,
+    SUICIDE,
+    CSAM,
+    HATRED,
+    HARASSMENT,
+    THREAT,
+    PROFANITY,
+]
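These O-numbered policy blocks are consumed positionally by Aegis.get_aegis_block_message earlier in this commit: a moderation output line such as "O3" is parsed as int("O3".lstrip("O")) - 1 and used to index UNSAFE_CATEGORIES. A minimal check:

    from cosmos_transfer1.auxiliary.guardrail.aegis.categories import UNSAFE_CATEGORIES

    idx = int("O3".lstrip("O")) - 1            # "O3" -> index 2
    print(UNSAFE_CATEGORIES[idx]["category"])  # criminal_planning

Because the instruction labels skip O9 while the list is contiguous, label and index diverge for O10, O11, and O12 (HARASSMENT, THREAT, PROFANITY).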
cosmos_transfer1/auxiliary/guardrail/blocklist/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
cosmos_transfer1/auxiliary/guardrail/blocklist/blocklist.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
import string
|
20 |
+
from difflib import SequenceMatcher
|
21 |
+
|
22 |
+
import nltk
|
23 |
+
from better_profanity import profanity
|
24 |
+
|
25 |
+
from cosmos_transfer1.auxiliary.guardrail.blocklist.utils import read_keyword_list_from_dir, to_ascii
|
26 |
+
from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
|
27 |
+
from cosmos_transfer1.utils import log, misc
|
28 |
+
|
29 |
+
CENSOR = misc.Color.red("*")
|
30 |
+
|
31 |
+
|
32 |
+
class Blocklist(ContentSafetyGuardrail):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
checkpoint_dir: str,
|
36 |
+
guardrail_partial_match_min_chars: int = 6,
|
37 |
+
guardrail_partial_match_letter_count: float = 0.4,
|
38 |
+
) -> None:
|
39 |
+
self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/blocklist")
|
40 |
+
nltk.data.path.append(os.path.join(self.checkpoint_dir, "nltk_data"))
|
41 |
+
self.lemmatizer = nltk.WordNetLemmatizer()
|
42 |
+
self.profanity = profanity
|
43 |
+
self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars
|
44 |
+
self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count
|
45 |
+
|
46 |
+
# Load blocklist and whitelist keywords
|
47 |
+
self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom"))
|
48 |
+
self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist"))
|
49 |
+
self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match"))
|
50 |
+
|
51 |
+
self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words)
|
52 |
+
log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist")
|
53 |
+
log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist")
|
54 |
+
log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist")
|
55 |
+
|
56 |
+
def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str:
|
57 |
+
"""Explicitly uncensor words that are in the whitelist."""
|
58 |
+
input_words = input_prompt.split()
|
59 |
+
censored_words = censored_prompt.split()
|
60 |
+
whitelist_words = set(self.whitelist_words)
|
61 |
+
for i, token in enumerate(input_words):
|
62 |
+
if token.strip(string.punctuation).lower() in whitelist_words:
|
63 |
+
censored_words[i] = token
|
64 |
+
censored_prompt = " ".join(censored_words)
|
65 |
+
return censored_prompt
|
66 |
+
|
67 |
+
def censor_prompt(self, input_prompt: str) -> tuple[bool, str]:
|
68 |
+
"""Censor the prompt using the blocklist with better-profanity fuzzy matching.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
input_prompt: input prompt to censor
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
bool: True if the prompt is blocked, False otherwise
|
75 |
+
str: A message indicating why the prompt was blocked
|
76 |
+
"""
|
77 |
+
censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR)
|
78 |
+
# Uncensor whitelisted words that were censored from blocklist fuzzy matching
|
79 |
+
censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt)
|
80 |
+
if CENSOR in censored_prompt:
|
81 |
+
return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}"
|
82 |
+
return False, ""
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
def check_partial_match(
|
86 |
+
normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float
|
87 |
+
) -> tuple[bool, str]:
|
88 |
+
"""
|
89 |
+
Check robustly if normalized word and the matching target have a difference of up to guardrail_partial_match_letter_count characters.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
normalized_prompt: a string with many words
|
93 |
+
normalized_word: a string with one or multiple words, its length is smaller than normalized_prompt
|
94 |
+
guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters)
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
bool: True if a match is found, False otherwise
|
98 |
+
str: A message indicating why the prompt was blocked
|
99 |
+
"""
|
100 |
+
prompt_words = normalized_prompt.split()
|
101 |
+
word_length = len(normalized_word.split())
|
102 |
+
max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float(
|
103 |
+
len(normalized_word)
|
104 |
+
)
|
105 |
+
|
106 |
+
for i in range(len(prompt_words) - word_length + 1):
|
107 |
+
# Extract a substring from the prompt with the same number of words as the normalized_word
|
108 |
+
substring = " ".join(prompt_words[i : i + word_length])
|
109 |
+
similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio()
|
110 |
+
            if similarity_ratio >= max_similarity_ratio:
                return (
                    True,
                    f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}",
                )

        return False, ""

    @staticmethod
    def check_against_whole_word_blocklist(
        prompt: str,
        blocklist: list[str],
        guardrail_partial_match_min_chars: int = 6,
        guardrail_partial_match_letter_count: float = 0.4,
    ) -> tuple[bool, str]:
        """
        Check if the prompt contains any whole words from the blocklist.
        The match is case-insensitive and robust to multiple spaces between words.

        Args:
            prompt: input prompt to check
            blocklist: list of words to check against
            guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match
            guardrail_partial_match_letter_count: maximum allowed difference in characters for a partial match

        Returns:
            bool: True if a match is found, False otherwise
            str: A message indicating why the prompt was blocked
        """
        # Normalize spaces and convert to lowercase
        normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower()

        for word in blocklist:
            # Normalize spaces and convert to lowercase for each blocklist word
            normalized_word = re.sub(r"\s+", " ", word).strip().lower()

            # Use word boundaries to ensure whole-word match
            if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt):
                return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}"

            # Check for partial match if the word is long enough
            if len(normalized_word) >= guardrail_partial_match_min_chars:
                match, message = Blocklist.check_partial_match(
                    normalized_prompt, normalized_word, guardrail_partial_match_letter_count
                )
                if match:
                    return True, message

        return False, ""

    def is_safe(self, input_prompt: str = "") -> tuple[bool, str]:
        """Check if the input prompt is safe using the blocklist."""
        # Check if the input is empty
        if not input_prompt:
            return False, "Input is empty"
        input_prompt = to_ascii(input_prompt)

        # Check the full sentence for censored words
        censored, message = self.censor_prompt(input_prompt)
        if censored:
            return False, message

        # Check lemmatized words for censored words
        tokens = nltk.word_tokenize(input_prompt)
        lemmas = [self.lemmatizer.lemmatize(token) for token in tokens]
        lemmatized_prompt = " ".join(lemmas)
        censored, message = self.censor_prompt(lemmatized_prompt)
        if censored:
            return False, message

        # Check for exact match blocklist words
        censored, message = self.check_against_whole_word_blocklist(
            input_prompt,
            self.exact_match_words,
            self.guardrail_partial_match_min_chars,
            self.guardrail_partial_match_letter_count,
        )
        if censored:
            return False, message

        # If all these checks pass, the input is safe
        return True, "Input is safe"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Blocklist checkpoint folder",
    )
    return parser.parse_args()


def main(args):
    blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[blocklist])
    with misc.timer("blocklist safety check"):
        safety, message = runner.run_safety_check(args.prompt)
    log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
    if not safety:
        log.info(f"Message: {message}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
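The whole-word check above is self-contained enough to exercise on its own. A minimal sketch of the same word-boundary idea, using a toy blocklist (the helper name and the words here are illustrative, not part of the module):

import re

def blocked_by_whole_word(prompt: str, blocklist: list[str]) -> bool:
    # Normalize whitespace and case, mirroring check_against_whole_word_blocklist
    normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower()
    for word in blocklist:
        normalized_word = re.sub(r"\s+", " ", word).strip().lower()
        # \b anchors ensure "cat" does not match inside "concatenate"
        if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt):
            return True
    return False

assert blocked_by_whole_word("a cat on a mat", ["cat"])
assert not blocked_by_whole_word("concatenate strings", ["cat"])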
cosmos_transfer1/auxiliary/guardrail/blocklist/utils.py
ADDED
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re

from cosmos_transfer1.utils import log


def read_keyword_list_from_dir(folder_path: str) -> list[str]:
    """Read keyword list from all files in a folder."""
    output_list = []
    file_list = []
    # Get list of files in the folder
    for file in os.listdir(folder_path):
        if os.path.isfile(os.path.join(folder_path, file)):
            file_list.append(file)

    # Process each file
    for file in file_list:
        file_path = os.path.join(folder_path, file)
        try:
            with open(file_path, "r") as f:
                output_list.extend([line.strip() for line in f.readlines()])
        except Exception as e:
            log.error(f"Error reading file {file}: {str(e)}")

    return output_list


def to_ascii(prompt: str) -> str:
    """Convert prompt to ASCII."""
    return re.sub(r"[^\x00-\x7F]+", " ", prompt)
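A quick sketch of how the keyword reader behaves, assuming the cosmos_transfer1 package is importable; the temporary folder and keywords are made up for illustration:

import os
import tempfile

from cosmos_transfer1.auxiliary.guardrail.blocklist.utils import read_keyword_list_from_dir

# Two one-keyword files in a temporary folder are merged into a single list
with tempfile.TemporaryDirectory() as folder:
    for name, word in [("a.txt", "foo"), ("b.txt", "bar")]:
        with open(os.path.join(folder, name), "w") as f:
            f.write(word + "\n")
    assert sorted(read_keyword_list_from_dir(folder)) == ["bar", "foo"]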
cosmos_transfer1/auxiliary/guardrail/common/__init__.py
ADDED
File without changes
cosmos_transfer1/auxiliary/guardrail/common/core.py
ADDED
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Tuple

import numpy as np

from cosmos_transfer1.utils import log


class ContentSafetyGuardrail:
    def is_safe(self, **kwargs) -> Tuple[bool, str]:
        raise NotImplementedError("Child classes must implement the is_safe method")


class PostprocessingGuardrail:
    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        raise NotImplementedError("Child classes must implement the postprocess method")


class GuardrailRunner:
    def __init__(
        self,
        safety_models: list[ContentSafetyGuardrail] | None = None,
        generic_block_msg: str = "",
        generic_safe_msg: str = "",
        postprocessors: list[PostprocessingGuardrail] | None = None,
    ):
        self.safety_models = safety_models
        self.generic_block_msg = generic_block_msg
        self.generic_safe_msg = generic_safe_msg if generic_safe_msg else "Prompt is safe"
        self.postprocessors = postprocessors

    def run_safety_check(self, input: Any) -> Tuple[bool, str]:
        """Run the safety check on the input."""
        if not self.safety_models:
            log.warning("No safety models found, returning safe")
            return True, self.generic_safe_msg

        for guardrail in self.safety_models:
            guardrail_name = str(guardrail.__class__.__name__).upper()
            log.debug(f"Running guardrail: {guardrail_name}")
            safe, message = guardrail.is_safe(input)
            if not safe:
                reasoning = self.generic_block_msg if self.generic_block_msg else f"{guardrail_name}: {message}"
                return False, reasoning
        return True, self.generic_safe_msg

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Run the postprocessing on the video frames."""
        if not self.postprocessors:
            log.warning("No postprocessors found, returning original frames")
            return frames

        for guardrail in self.postprocessors:
            guardrail_name = str(guardrail.__class__.__name__).upper()
            log.debug(f"Running guardrail: {guardrail_name}")
            frames = guardrail.postprocess(frames)
        return frames
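GuardrailRunner treats safety models as pluggable: anything implementing is_safe can be chained. A minimal sketch with a toy length-based guardrail (MaxLengthGuardrail is a hypothetical example, not part of the repo):

from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner

class MaxLengthGuardrail(ContentSafetyGuardrail):
    """Toy safety model: blocks prompts longer than a fixed length."""

    def __init__(self, max_chars: int = 100):
        self.max_chars = max_chars

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        if len(prompt) > self.max_chars:
            return False, f"prompt exceeds {self.max_chars} characters"
        return True, ""

runner = GuardrailRunner(safety_models=[MaxLengthGuardrail(max_chars=20)])
safe, reason = runner.run_safety_check("a short prompt")
assert safe
safe, reason = runner.run_safety_check("x" * 50)
# The block reason is prefixed with the upper-cased class name
assert not safe and "MAXLENGTHGUARDRAIL" in reason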
cosmos_transfer1/auxiliary/guardrail/common/io_utils.py
ADDED
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
from dataclasses import dataclass

import imageio
import numpy as np

from cosmos_transfer1.utils import log


@dataclass
class VideoData:
    frames: np.ndarray  # Shape: [B, H, W, C]
    fps: int
    duration: int  # in seconds


def get_video_filepaths(input_dir: str) -> list[str]:
    """Get a list of filepaths for all videos in the input directory."""
    paths = glob.glob(f"{input_dir}/**/*.mp4", recursive=True)
    paths += glob.glob(f"{input_dir}/**/*.avi", recursive=True)
    paths += glob.glob(f"{input_dir}/**/*.mov", recursive=True)
    paths = sorted(paths)
    log.debug(f"Found {len(paths)} videos")
    return paths


def read_video(filepath: str) -> VideoData:
    """Read a video file and extract its frames and metadata."""
    try:
        reader = imageio.get_reader(filepath, "ffmpeg")
    except Exception as e:
        raise ValueError(f"Failed to read video file: {filepath}") from e

    # Extract metadata from the video file
    try:
        metadata = reader.get_meta_data()
        fps = metadata.get("fps")
        duration = metadata.get("duration")
    except Exception as e:
        reader.close()
        raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e

    # Extract frames from the video file
    try:
        frames = np.array([frame for frame in reader])
    except Exception as e:
        raise ValueError(f"Failed to extract frames from video file: {filepath}") from e
    finally:
        reader.close()

    return VideoData(frames=frames, fps=fps, duration=duration)


def save_video(filepath: str, frames: np.ndarray, fps: int) -> None:
    """Save a video file from a sequence of frames."""
    # Create the writer before the try block so the finally clause never sees an unbound name
    writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1)
    try:
        for frame in frames:
            writer.append_data(frame)
    except Exception as e:
        raise ValueError(f"Failed to save video file to {filepath}") from e
    finally:
        writer.close()
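A round-trip sketch for these helpers, assuming imageio's ffmpeg backend is installed; the filename and synthetic frames are illustrative:

import numpy as np

from cosmos_transfer1.auxiliary.guardrail.common.io_utils import read_video, save_video

# Eight random 64x64 RGB frames, written at 8 FPS and read back
frames = np.random.randint(0, 255, size=(8, 64, 64, 3), dtype=np.uint8)
save_video("roundtrip.mp4", frames, fps=8)
video = read_video("roundtrip.mp4")
print(video.frames.shape, video.fps, video.duration)  # e.g. (8, 64, 64, 3) 8.0 1.0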
cosmos_transfer1/auxiliary/guardrail/common/presets.py
ADDED
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

from cosmos_transfer1.auxiliary.guardrail.blocklist.blocklist import Blocklist
from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner
from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter
from cosmos_transfer1.auxiliary.guardrail.llamaGuard3.llamaGuard3 import LlamaGuard3
from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.video_content_safety_filter import (
    VideoContentSafetyFilter,
)
from cosmos_transfer1.utils import log


def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Create the text guardrail runner."""
    return GuardrailRunner(safety_models=[Blocklist(checkpoint_dir), LlamaGuard3(checkpoint_dir)])


def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Create the video guardrail runner."""
    return GuardrailRunner(
        safety_models=[VideoContentSafetyFilter(checkpoint_dir)],
        postprocessors=[RetinaFaceFilter(checkpoint_dir)],
    )


def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool:
    """Run the text guardrail on the prompt, checking for content safety.

    Args:
        prompt: The text prompt.
        guardrail_runner: The text guardrail runner.

    Returns:
        bool: Whether the prompt is safe.
    """
    is_safe, message = guardrail_runner.run_safety_check(prompt)
    if not is_safe:
        log.critical(f"GUARDRAIL BLOCKED: {message}")
    return is_safe


def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None:
    """Run the video guardrail on the frames, checking for content safety and applying face blur.

    Args:
        frames: The frames of the generated video.
        guardrail_runner: The video guardrail runner.

    Returns:
        The processed frames if safe, otherwise None.
    """
    is_safe, message = guardrail_runner.run_safety_check(frames)
    if not is_safe:
        log.critical(f"GUARDRAIL BLOCKED: {message}")
        return None

    frames = guardrail_runner.postprocess(frames)
    return frames
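A usage sketch for the presets, assuming the guardrail checkpoints have already been downloaded under a local `checkpoints` folder (the path and prompt are illustrative):

from cosmos_transfer1.auxiliary.guardrail.common.presets import (
    create_text_guardrail_runner,
    run_text_guardrail,
)

checkpoint_dir = "checkpoints"  # assumed location of the downloaded guardrail weights
text_runner = create_text_guardrail_runner(checkpoint_dir)
if run_text_guardrail("A robot arm stacks boxes in a warehouse.", text_runner):
    print("prompt passed the text guardrail")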
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/__init__.py
ADDED
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/blur_utils.py
ADDED
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np


def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray:
    """
    Pixelate a face region by reducing resolution and then upscaling.

    Args:
        face_img: Face region to pixelate
        blocks: Number of blocks to divide the face into (in each dimension)

    Returns:
        Pixelated face region
    """
    h, w = face_img.shape[:2]
    # Shrink the image and scale back up to create pixelation effect
    temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR)
    pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
    return pixelated
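pixelate_face is pure NumPy/OpenCV, so it can be sanity-checked without any model. A small sketch with a synthetic gradient crop:

import numpy as np

from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face

# A 48x64 gradient "face" crop; pixelation keeps the shape but removes detail
face = np.tile(np.arange(64, dtype=np.uint8), (48, 1))[..., None].repeat(3, axis=-1)
out = pixelate_face(face, blocks=5)
assert out.shape == face.shape
# With blocks=5 and nearest-neighbour upscaling, the output has at most 5x5 distinct tiles
print(len(np.unique(out.reshape(-1, 3), axis=0)))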
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py
ADDED
@@ -0,0 +1,225 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
import torch
from retinaface.data import cfg_re50
from retinaface.layers.functions.prior_box import PriorBox
from retinaface.models.retinaface import RetinaFace
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner, PostprocessingGuardrail
from cosmos_transfer1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video, save_video
from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face
from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.retinaface_utils import (
    decode_batch,
    filter_detected_boxes,
    load_model,
)
from cosmos_transfer1.utils import log, misc

# RetinaFace model constants from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
TOP_K = 5_000
KEEP_TOP_K = 750
NMS_THRESHOLD = 0.4


class RetinaFaceFilter(PostprocessingGuardrail):
    def __init__(
        self,
        checkpoint_dir: str,
        batch_size: int = 1,
        confidence_threshold: float = 0.7,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """
        Initialize the RetinaFace model for face detection and blurring.

        Args:
            checkpoint_dir: Path to the checkpoint folder containing the RetinaFace weights
            batch_size: Batch size for RetinaFace inference and processing
            confidence_threshold: Minimum confidence score to consider a face detection
        """
        self.checkpoint = f"{checkpoint_dir}/nvidia/Cosmos-Guardrail1/face_blur_filter/Resnet50_Final.pth"
        self.cfg = cfg_re50
        self.batch_size = batch_size
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.dtype = torch.float32

        # Disable loading ResNet pretrained weights
        self.cfg["pretrain"] = False
        self.net = RetinaFace(cfg=self.cfg, phase="test")
        cpu = self.device == "cpu"

        # Load from RetinaFace pretrained checkpoint
        self.net = load_model(self.net, self.checkpoint, cpu)
        self.net.to(self.device, dtype=self.dtype).eval()

    def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor:
        """Preprocess a sequence of frames for face detection.

        Args:
            frames: Input frames

        Returns:
            Preprocessed frames tensor
        """
        with torch.no_grad():
            frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype)  # Shape: [T, H, W, C]
            frames_tensor = frames_tensor.permute(0, 3, 1, 2)  # Shape: [T, C, H, W]
            frames_tensor = frames_tensor[:, [2, 1, 0], :, :]  # RGB to BGR to match RetinaFace model input
            means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1)
            frames_tensor = frames_tensor - means  # Subtract mean BGR values for each channel
            return frames_tensor

    def blur_detected_faces(
        self,
        frames: np.ndarray,
        batch_loc: torch.Tensor,
        batch_conf: torch.Tensor,
        prior_data: torch.Tensor,
        scale: torch.Tensor,
        min_size: tuple[int, int] = (20, 20),
    ) -> list[np.ndarray]:
        """Blur detected faces in a batch of frames using RetinaFace predictions.

        Args:
            frames: Input frames
            batch_loc: Batched location predictions
            batch_conf: Batched confidence scores
            prior_data: Prior boxes for the video
            scale: Scale factor for resizing detections
            min_size: Minimum size of a detected face region in pixels

        Returns:
            Processed frames with pixelated faces
        """
        with torch.no_grad():
            batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"])
            batch_boxes = batch_boxes * scale

            blurred_frames = []
            for i, boxes in enumerate(batch_boxes):
                boxes = boxes.detach().cpu().numpy()
                scores = batch_conf[i, :, 1].detach().cpu().numpy()

                filtered_boxes = filter_detected_boxes(
                    boxes,
                    scores,
                    confidence_threshold=self.confidence_threshold,
                    nms_threshold=NMS_THRESHOLD,
                    top_k=TOP_K,
                    keep_top_k=KEEP_TOP_K,
                )

                frame = frames[i]
                for box in filtered_boxes:
                    x1, y1, x2, y2 = map(int, box)
                    # Ignore bounding boxes smaller than the minimum size
                    if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]:
                        continue
                    max_h, max_w = frame.shape[:2]
                    face_roi = frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)]
                    blurred_face = pixelate_face(face_roi)
                    frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] = blurred_face
                blurred_frames.append(frame)

            return blurred_frames

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Blur faces in a sequence of frames.

        Args:
            frames: Input frames

        Returns:
            Processed frames with pixelated faces
        """
        # Create dataset and dataloader
        frames_tensor = self.preprocess_frames(frames)
        dataset = TensorDataset(frames_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        processed_frames, processed_batches = [], []

        prior_data, scale = None, None
        for i, batch in enumerate(dataloader):
            batch = batch[0]
            h, w = batch.shape[-2:]  # Batch shape: [C, H, W]

            with torch.no_grad():
                # Generate priors for the video
                if prior_data is None:
                    priorbox = PriorBox(self.cfg, image_size=(h, w))
                    priors = priorbox.forward()
                    priors = priors.to(self.device, dtype=self.dtype)
                    prior_data = priors.data

                # Get scale for resizing detections
                if scale is None:
                    scale = torch.Tensor([w, h, w, h])
                    scale = scale.to(self.device, dtype=self.dtype)

                batch_loc, batch_conf, _ = self.net(batch)

            # Blur detected faces in each batch of frames
            start_idx = i * self.batch_size
            end_idx = min(start_idx + self.batch_size, len(frames))
            processed_batches.append(
                self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale)
            )

        processed_frames = [frame for batch in processed_batches for frame in batch]
        return np.array(processed_frames)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos")
    parser.add_argument("--output_dir", type=str, required=True, help="Path for saving processed videos")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the checkpoint folder containing the RetinaFace weights",
    )
    return parser.parse_args()


def main(args):
    filepaths = get_video_filepaths(args.input_dir)
    if not filepaths:
        log.error(f"No video files found in directory: {args.input_dir}")
        return

    face_blur = RetinaFaceFilter(checkpoint_dir=args.checkpoint_dir)
    postprocessing_runner = GuardrailRunner(postprocessors=[face_blur])
    os.makedirs(args.output_dir, exist_ok=True)

    for filepath in tqdm(filepaths):
        video_data = read_video(filepath)
        with misc.timer("face blur filter"):
            frames = postprocessing_runner.postprocess(video_data.frames)

        output_path = os.path.join(args.output_dir, os.path.basename(filepath))
        save_video(output_path, frames, video_data.fps)


if __name__ == "__main__":
    args = parse_args()
    main(args)
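A usage sketch for the postprocessor, assuming the Cosmos-Guardrail1 RetinaFace weights sit under a local `checkpoints` folder; the random frames stand in for a generated video:

import numpy as np

from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner
from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter

checkpoint_dir = "checkpoints"  # assumed root containing nvidia/Cosmos-Guardrail1/...
face_blur = RetinaFaceFilter(checkpoint_dir=checkpoint_dir, batch_size=4)
runner = GuardrailRunner(postprocessors=[face_blur])

frames = np.random.randint(0, 255, size=(16, 256, 256, 3), dtype=np.uint8)  # stand-in video
blurred = runner.postprocess(frames)
assert blurred.shape == frames.shape  # face blur preserves the frame geometry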
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py
ADDED
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from retinaface.utils.nms.py_cpu_nms import py_cpu_nms

from cosmos_transfer1.utils import log


# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
def filter_detected_boxes(boxes, scores, confidence_threshold, nms_threshold, top_k, keep_top_k):
    """Filter boxes based on confidence score and remove overlapping boxes using NMS."""
    # Keep detections with confidence above threshold
    inds = np.where(scores > confidence_threshold)[0]
    boxes = boxes[inds]
    scores = scores[inds]

    # Sort by confidence and keep top K detections
    order = scores.argsort()[::-1][:top_k]
    boxes = boxes[order]
    scores = scores[order]

    # Run non-maximum-suppression (NMS) to remove overlapping boxes
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, nms_threshold)
    dets = dets[keep, :]
    dets = dets[:keep_top_k, :]
    boxes = dets[:, :-1]
    return boxes


# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py to handle batched inputs
def decode_batch(loc, priors, variances):
    """Decode batched locations from predictions using priors and variances.

    Args:
        loc (tensor): Batched location predictions for loc layers.
            Shape: [batch_size, num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4]
        variances (list[float]): Variances of prior boxes.

    Return:
        Decoded batched bounding box predictions
            Shape: [batch_size, num_priors, 4]
    """
    batch_size = loc.size(0)
    priors = priors.unsqueeze(0).expand(batch_size, -1, -1)

    boxes = torch.cat(
        (
            priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
            priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]),
        ),
        dim=2,
    )

    boxes[:, :, :2] -= boxes[:, :, 2:] / 2
    boxes[:, :, 2:] += boxes[:, :, :2]
    return boxes


# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
def _check_keys(model, pretrained_state_dict):
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    log.debug("Missing keys:{}".format(len(missing_keys)))
    log.debug("Unused checkpoint keys:{}".format(len(unused_pretrained_keys)))
    log.debug("Used keys:{}".format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
    return True


# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
def _remove_prefix(state_dict, prefix):
    """Old version of the model is stored with all names of parameters sharing common prefix 'module.'"""
    log.debug("Removing prefix '{}'".format(prefix))

    def f(x):
        return x.split(prefix, 1)[-1] if x.startswith(prefix) else x

    return {f(key): value for key, value in state_dict.items()}


# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
def load_model(model, pretrained_path, load_to_cpu):
    log.debug("Loading pretrained model from {}".format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage, weights_only=True)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(
            pretrained_path, map_location=lambda storage, loc: storage.cuda(device), weights_only=True
        )
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = _remove_prefix(pretrained_dict["state_dict"], "module.")
    else:
        pretrained_dict = _remove_prefix(pretrained_dict, "module.")
    _check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
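decode_batch has a convenient invariant: zero offsets decode back to the priors, converted from center-size to corner form. A small sketch that checks this:

import torch

from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.retinaface_utils import decode_batch

# Two frames, three priors each; zero offsets should decode to the priors themselves
priors = torch.tensor([[0.5, 0.5, 0.2, 0.2],
                       [0.3, 0.3, 0.1, 0.1],
                       [0.7, 0.7, 0.4, 0.4]])
loc = torch.zeros(2, 3, 4)
boxes = decode_batch(loc, priors, variances=[0.1, 0.2])
assert boxes.shape == (2, 3, 4)
# Center-size priors become corner-form boxes: (cx - w/2, cy - h/2, cx + w/2, cy + h/2)
print(boxes[0, 0])  # tensor([0.4000, 0.4000, 0.6000, 0.6000])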
cosmos_transfer1/auxiliary/guardrail/llamaGuard3/__init__.py
ADDED
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cosmos_transfer1/auxiliary/guardrail/llamaGuard3/categories.py
ADDED
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

UNSAFE_CATEGORIES = {
    "S1": "Violent Crimes.",
    "S2": "Non-Violent Crimes.",
    "S3": "Sex Crimes.",
    "S4": "Child Exploitation.",
    "S5": "Defamation.",
    "S6": "Specialized Advice.",
    "S7": "Privacy.",
    "S8": "Intellectual Property.",
    "S9": "Indiscriminate Weapons.",
    "S10": "Hate.",
    "S11": "Self-Harm.",
    "S12": "Sexual Content.",
    "S13": "Elections.",
    "S14": "Code Interpreter Abuse.",
}
cosmos_transfer1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py
ADDED
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
from cosmos_transfer1.auxiliary.guardrail.llamaGuard3.categories import UNSAFE_CATEGORIES
from cosmos_transfer1.utils import log, misc

SAFE = misc.Color.green("SAFE")
UNSAFE = misc.Color.red("UNSAFE")


class LlamaGuard3(ContentSafetyGuardrail):
    def __init__(
        self,
        checkpoint_dir: str,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = torch.bfloat16

        model_id = "meta-llama/Llama-Guard-3-8B"
        model_dir = os.path.join(self.checkpoint_dir, model_id)

        self.model = AutoModelForCausalLM.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        self.model.to(self.device, dtype=self.dtype).eval()

    def get_llamaGuard3_block_message(self, moderation_output: str) -> str:
        """Extract the blocked category from the Llama Guard 3 model output."""
        block_msg = "Prompt blocked by Llama Guard 3."
        try:
            lines = moderation_output.splitlines()
            categories_detected = []
            for line in lines[1:]:
                line_stripped = line.split("<|eot_id|>")[0].strip()
                for category in line_stripped.split(","):
                    category = category.strip()
                    if category not in UNSAFE_CATEGORIES:
                        log.warning(f"Unrecognized category from moderation output: {category}")
                    else:
                        categories_detected.append(category)
            if len(categories_detected) > 0:
                blocked_categories = ", ".join([UNSAFE_CATEGORIES[category][:-1] for category in categories_detected])
                block_msg = f"{block_msg} Violations: {blocked_categories}."
        except Exception as e:
            log.warning(f"Unable to extract blocked category from Llama Guard 3 output: {e}")
        return block_msg

    def filter_llamaGuard3_output(self, prompt: str) -> tuple[bool, str]:
        """Filter the Llama Guard 3 model output and return the safety status and message."""
        conversation = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(
            conversation, categories=UNSAFE_CATEGORIES, return_tensors="pt"
        ).to(self.device)
        prompt_len = input_ids.shape[1]
        output = self.model.generate(
            input_ids=input_ids,
            max_new_tokens=100,
            return_dict_in_generate=True,
            pad_token_id=0,
        )
        generated_tokens = output.sequences[:, prompt_len:]
        moderation_output = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=False).strip()

        if "unsafe" in moderation_output.lower():
            block_msg = self.get_llamaGuard3_block_message(moderation_output)
            return False, block_msg
        else:
            return True, ""

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        """Check if the input prompt is safe according to the Llama Guard 3 model."""
        try:
            return self.filter_llamaGuard3_output(prompt)
        except Exception as e:
            log.error(f"Unexpected error occurred when running Llama Guard 3 guardrail: {e}")
            return True, "Unexpected error occurred when running Llama Guard 3 guardrail."


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Llama Guard 3 checkpoint folder",
    )
    return parser.parse_args()


def main(args):
    llamaGuard3 = LlamaGuard3(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[llamaGuard3])
    with misc.timer("Llama Guard 3 safety check"):
        safety, message = runner.run_safety_check(args.prompt)
    log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
    if not safety:
        log.info(f"Message: {message}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
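Llama Guard 3 replies with a verdict line ("safe"/"unsafe") followed by comma-separated category codes, which get_llamaGuard3_block_message parses. A sketch of that parsing on a hand-written reply (the moderation string below is illustrative, not real model output):

from cosmos_transfer1.auxiliary.guardrail.llamaGuard3.categories import UNSAFE_CATEGORIES

moderation_output = "unsafe\nS1,S10<|eot_id|>"  # illustrative model reply
detected = []
for line in moderation_output.splitlines()[1:]:
    for category in line.split("<|eot_id|>")[0].strip().split(","):
        category = category.strip()
        if category in UNSAFE_CATEGORIES:
            detected.append(category)
print(detected)  # ['S1', 'S10'] -> "Violent Crimes", "Hate"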
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/__init__.py
ADDED
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py
ADDED
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import attrs
import torch
import torch.nn as nn

from cosmos_transfer1.utils.ddp_config import make_freezable


@make_freezable
@attrs.define(slots=False)
class ModelConfig:
    input_size: int = 1152
    num_classes: int = 7


class SafetyClassifier(nn.Module):
    def __init__(self, input_size: int = 1024, num_classes: int = 2):
        super().__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, self.num_classes),
            # Note: No activation function here; CrossEntropyLoss expects raw logits
        )

    def forward(self, x):
        return self.layers(x)


class VideoSafetyModel(nn.Module):
    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        self.num_classes = config.num_classes
        self.network = SafetyClassifier(input_size=config.input_size, num_classes=self.num_classes)

    @torch.inference_mode()
    def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        logits = self.network(data_batch["data"].cuda())
        return {"logits": logits}
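The classifier head can be shape-checked without the SigLIP weights. A minimal sketch that calls model.network directly (as the content filter's __infer does), which keeps the test CPU-friendly since forward() moves inputs to CUDA; the weights here are random, so only the shapes are meaningful:

import torch

from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel

config = ModelConfig(input_size=1152, num_classes=7)
model = VideoSafetyModel(config).eval()  # eval mode so BatchNorm uses running stats

# A batch of 4 SigLIP-sized embeddings -> one logit per safety class
embs = torch.randn(4, 1152)
logits = model.network(embs)
assert logits.shape == (4, 7)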
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py
ADDED
@@ -0,0 +1,185 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
from typing import Iterable, Tuple, Union

import torch
from PIL import Image

from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
from cosmos_transfer1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video
from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel
from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.vision_encoder import SigLIPEncoder
from cosmos_transfer1.utils import log, misc

# Define the class index to class name mapping for multi-class classification
CLASS_IDX_TO_NAME = {
    0: "Safe",
    1: "Sexual_Content",
    3: "Drugs",
    4: "Child_Abuse",
    5: "Hate_and_Harassment",
    6: "Self-Harm",
}


class VideoContentSafetyFilter(ContentSafetyGuardrail):
    def __init__(
        self,
        checkpoint_dir: str,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/video_content_safety_filter")
        self.device = device
        self.dtype = torch.float32

        # Initialize the SigLIP encoder
        self.encoder = SigLIPEncoder(checkpoint_dir=self.checkpoint_dir, device=device, dtype=self.dtype)

        # Use ModelConfig directly for inference configuration
        model_config = ModelConfig(input_size=1152, num_classes=7)

        # Load the multi-class classifier
        self.model = VideoSafetyModel(model_config)
        safety_filter_local_path = os.path.join(self.checkpoint_dir, "safety_filter.pt")
        checkpoint = torch.load(safety_filter_local_path, map_location=torch.device("cpu"), weights_only=True)
        self.model.load_state_dict(checkpoint["model"])
        self.model.to(self.device, dtype=self.dtype).eval()

    @torch.inference_mode()
    def __infer(self, pil_image: Image.Image) -> int:
        """Infer the class of the image."""
        image_embs = self.encoder.encode_image(pil_image)
        logits = self.model.network(image_embs)
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        return predicted_class

    def is_safe_file(self, filepath: str) -> bool:
        """Check if the video file is safe."""
        video_data = read_video(filepath)

        # Sample frames at 2 FPS; guard against videos whose FPS is below the sample rate
        sample_rate = 2  # frames per second
        frame_interval = max(1, int(video_data.fps / sample_rate))
        frame_numbers = list(range(0, int(video_data.fps * video_data.duration), frame_interval))

        is_safe = True
        frame_scores = []

        for frame_number in frame_numbers:
            try:
                frame = video_data.frames[frame_number]
                pil_image = Image.fromarray(frame)
                predicted_class = self.__infer(pil_image)
                class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Safe")
                frame_scores.append({"frame_number": frame_number, "class": class_name})

                # If any frame is not "Safe", mark the video as unsafe
                if class_name != "Safe":
                    is_safe = False
                    break

            except Exception as e:
                log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}")
                continue

        # Prepare data for JSON (use a separate name so the VideoData object is not shadowed)
        video_report = {
            "filepath": filepath,
            "is_safe": is_safe,
            "video_length": video_data.duration,
            "fps": video_data.fps,
            "frame_scores": frame_scores,
        }

        log.info(f"Video {filepath} is {'SAFE' if is_safe else 'UNSAFE'}.")
        log.debug(f"Video data: {json.dumps(video_report, indent=4)}")
        return is_safe

    def is_safe_frames(self, frames: Iterable) -> bool:
        """Check if the generated video frames are safe."""
        frame_scores = []
        total_frames = 0
        safe_frames = 0

        for frame_number, frame in enumerate(frames):
            try:
                total_frames += 1
                pil_image = Image.fromarray(frame)
                predicted_class = self.__infer(pil_image)
                class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Safe")
                frame_scores.append({"frame_number": frame_number, "class": class_name})

                if class_name == "Safe":
                    safe_frames += 1

            except Exception as e:
                log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}")
                continue

        # Decide if the video is safe based on the ratio of safe frames
        is_safe = False
        if total_frames > 0:
            is_safe = (safe_frames / total_frames) >= 0.95

        frames_report = {
            "is_safe": is_safe,
            "frame_scores": frame_scores,
        }

        log.debug(f"Frames data: {json.dumps(frames_report, indent=4)}")
        return is_safe

    def is_safe(self, input: Union[str, Iterable]) -> Tuple[bool, str]:
        if isinstance(input, str):
            is_safe = self.is_safe_file(input)
            return is_safe, "safe video detected" if is_safe else "unsafe video detected"
        else:
            is_safe = self.is_safe_frames(input)
            return is_safe, "safe frames detected" if is_safe else "unsafe frames detected"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Video Content Safety Filter checkpoint folder",
    )
    return parser.parse_args()


def main(args):
    filepaths = get_video_filepaths(args.input_dir)
    if not filepaths:
        log.error(f"No video files found in directory: {args.input_dir}")
        return

    video_filter = VideoContentSafetyFilter(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe")

    for filepath in filepaths:
        with misc.timer("video content safety filter"):
            _ = runner.run_safety_check(filepath)


if __name__ == "__main__":
    args = parse_args()
    main(args)
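Two numbers drive the video verdicts: the 2 FPS sampling stride in is_safe_file and the 95% safe-frame ratio in is_safe_frames. A quick arithmetic sketch:

# Frame sampling used by is_safe_file: one sampled frame every fps/2 source frames
fps, duration = 30, 5.0
sample_rate = 2
frame_interval = max(1, int(fps / sample_rate))           # 15
frame_numbers = list(range(0, int(fps * duration), frame_interval))
assert len(frame_numbers) == 10                           # 2 frames per second for 5 s

# Decision rule of is_safe_frames: at least 95% of the frames must classify as "Safe"
safe_frames, total_frames = 19, 20
assert (safe_frames / total_frames) >= 0.95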
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py
ADDED
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor


class SigLIPEncoder(torch.nn.Module):
    def __init__(
        self,
        checkpoint_dir: str,
        model_name: str = "google/siglip-so400m-patch14-384",
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.float32,
    ) -> None:
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = dtype
        self.model = SiglipModel.from_pretrained(model_name, cache_dir=self.checkpoint_dir)
        self.processor = SiglipProcessor.from_pretrained(model_name, cache_dir=self.checkpoint_dir)
        self.model.to(self.device, dtype=self.dtype).eval()

    @torch.inference_mode()
    def encode_image(self, input_img: Image.Image) -> torch.Tensor:
        """Encode an image into a feature vector."""
        with torch.no_grad():
            inputs = self.processor(images=input_img, return_tensors="pt").to(self.device, dtype=self.dtype)
            image_features = self.model.get_image_features(**inputs)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features
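encode_image L2-normalizes its output, so embeddings live on the unit sphere and cosine similarity reduces to a dot product. A sketch assuming the SigLIP weights are cached under a local `checkpoints` folder (the 1152-dim shape holds for the so400m checkpoint):

import torch
from PIL import Image

from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.vision_encoder import SigLIPEncoder

encoder = SigLIPEncoder(checkpoint_dir="checkpoints")  # assumed cache directory
image = Image.new("RGB", (384, 384), color=(128, 64, 32))
emb = encoder.encode_image(image)
print(emb.shape)  # e.g. torch.Size([1, 1152])
# Unit norm after the in-place normalization above
assert torch.allclose(emb.norm(dim=-1), torch.ones(1).to(emb.device), atol=1e-4)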
cosmos_transfer1/auxiliary/human_keypoint/human_keypoint.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import numpy as np
from rtmlib import Wholebody

from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import (
    coco_wholebody_133_skeleton,
    openpose134_skeleton,
)
from cosmos_transfer1.utils import log


class HumanKeypointModel:
    def __init__(self, to_openpose=True, conf_thres=0.6):
        self.model = Wholebody(
            to_openpose=to_openpose,
            mode="performance",
            backend="onnxruntime",
            device="cuda",
        )
        self.to_openpose = to_openpose
        self.conf_thres = conf_thres

    def __call__(self, input_video: str, output_video: str = "keypoint.mp4") -> str:
        """
        Generate the human body keypoint plot for the keypointControlNet video2world model.
        Input: mp4 video.
        Output: mp4 keypoint video with the same spatial and temporal dimensions as the input video.
        """
        log.info(f"Processing video: {input_video} to generate keypoint video: {output_video}")
        assert os.path.exists(input_video)

        cap = cv2.VideoCapture(input_video)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_size = (frame_width, frame_height)

        # Video writer for the skeleton frames.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        skeleton_writer = cv2.VideoWriter(output_video, fourcc, fps, frame_size)

        log.info(f"frame width: {frame_width}, frame height: {frame_height}, fps: {fps}")
        log.info("start pose estimation for frames..")

        # Process each frame
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Create a black background frame
            black_frame = np.zeros_like(frame)

            # Run pose estimation
            keypoints, scores = self.model(frame)

            if keypoints is not None and len(keypoints) > 0:
                skeleton_frame = self.plot_person_kpts(
                    black_frame,
                    keypoints,
                    scores,
                    kpt_thr=self.conf_thres,
                    openpose_format=True,
                    line_width=4,
                )  # (h, w, 3)
            else:
                skeleton_frame = black_frame

            # Reverse the channel order (RGB -> BGR) for OpenCV's writer.
            skeleton_writer.write(skeleton_frame[:, :, ::-1])

        cap.release()
        skeleton_writer.release()
        return output_video

    def draw_skeleton(
        self,
        img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        radius: int = 2,
        line_width: int = 4,
    ):
        skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton
        assert len(keypoints.shape) == 2
        keypoint_info, skeleton_info = (
            skeleton_topology["keypoint_info"],
            skeleton_topology["skeleton_info"],
        )
        vis_kpt = [s >= kpt_thr for s in scores]
        link_dict = {}
        for i, kpt_info in keypoint_info.items():
            kpt_color = tuple(kpt_info["color"])
            link_dict[kpt_info["name"]] = kpt_info["id"]

            kpt = keypoints[i]

            if vis_kpt[i]:
                img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1)

        for i, ske_info in skeleton_info.items():
            link = ske_info["link"]
            pt0, pt1 = link_dict[link[0]], link_dict[link[1]]

            if vis_kpt[pt0] and vis_kpt[pt1]:
                link_color = ske_info["color"]
                kpt0 = keypoints[pt0]
                kpt1 = keypoints[pt1]

                img = cv2.line(
                    img, (int(kpt0[0]), int(kpt0[1])), (int(kpt1[0]), int(kpt1[1])), link_color, thickness=line_width
                )

        return img

    def plot_person_kpts(
        self,
        pose_vis_img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        line_width: int = 4,
    ) -> np.ndarray:
        """
        Plot the keypoints of every detected person, updating the pose image in place.
        """
        for kpts, ss in zip(keypoints, scores):
            try:
                pose_vis_img = self.draw_skeleton(
                    pose_vis_img, kpts, ss, kpt_thr=kpt_thr, openpose_format=openpose_format, line_width=line_width
                )
            except ValueError as e:
                log.error(f"Error in draw_skeleton func, {e}")

        return pose_vis_img
```
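A minimal usage sketch for the keypoint model; the video paths below are placeholders, and rtmlib with its ONNX Runtime CUDA backend must be installed:

```python
from cosmos_transfer1.auxiliary.human_keypoint.human_keypoint import HumanKeypointModel

# Hypothetical paths: any mp4 with visible people will do.
model = HumanKeypointModel(to_openpose=True, conf_thres=0.6)
model("assets/example_input.mp4", output_video="outputs/keypoint.mp4")
```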
cosmos_transfer1/auxiliary/robot_augmentation/README.md
ADDED
@@ -0,0 +1,112 @@
# Robot Data Augmentation with Cosmos-Transfer1

This pipeline provides a two-step process for augmenting robotics videos with **Cosmos-Transfer1-7B**. It uses **spatial-temporal control** to modify the background while preserving the shape and/or appearance of the robot in the foreground.

## Overview of Settings

We propose two augmentation settings, sketched in code below:

### Setting 1 (fg_vis_edge_bg_seg): Preserve the Shape and Appearance of the Robot (Foreground)
- **Foreground Controls**: `Edge`, `Vis`
- **Background Controls**: `Segmentation`
- **Weights**:
  - `w_edge(FG) = 1`
  - `w_vis(FG) = 1`
  - `w_seg(BG) = 1`
  - All other weights = 0

### Setting 2 (fg_edge_bg_seg): Preserve Only the Shape of the Robot (Foreground)
- **Foreground Controls**: `Edge`
- **Background Controls**: `Segmentation`
- **Weights**:
  - `w_edge(FG) = 1`
  - `w_seg(BG) = 1`
  - All other weights = 0
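Concretely, these settings correspond to the per-modality foreground/background weight dictionaries hard-coded in `spatial_temporal_weight.py` (included later in this commit):

```python
# Per-pixel control weights by modality, as defined in WeightSettings.get_settings().
# "foreground" applies to robot pixels, "background" to everything else.
settings = {
    "fg_vis_edge_bg_seg": {  # setting 1
        "depth": {"foreground": 0.0, "background": 0.0},
        "vis": {"foreground": 1.0, "background": 0.0},
        "edge": {"foreground": 1.0, "background": 0.0},
        "seg": {"foreground": 0.0, "background": 1.0},
    },
    "fg_edge_bg_seg": {  # setting 2: like setting 1, but with vis zeroed out
        "depth": {"foreground": 0.0, "background": 0.0},
        "vis": {"foreground": 0.0, "background": 0.0},
        "edge": {"foreground": 1.0, "background": 0.0},
        "seg": {"foreground": 0.0, "background": 1.0},
    },
}
```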
## Step-by-Step Instructions

### Step 1: Generate Spatial-Temporal Weights

This script extracts foreground (robot) and background information from semantic segmentation data. It processes per-frame segmentation masks and color-to-class mappings to generate a spatial-temporal weight matrix for each control modality based on the selected setting.

#### Input Requirements:
- A `segmentation` folder containing per-frame segmentation masks in PNG format
- A `segmentation_label` folder containing a color-to-class mapping JSON file for each frame, for example:
  ```json
  {
      "(29, 0, 0, 255)": {
          "class": "gripper0_right_r_palm_vis"
      },
      "(31, 0, 0, 255)": {
          "class": "gripper0_right_R_thumb_proximal_base_link_vis"
      },
      "(33, 0, 0, 255)": {
          "class": "gripper0_right_R_thumb_proximal_link_vis"
      }
  }
  ```
- An input video file

Here is an example input format:
[Example input directory](https://github.com/google-deepmind/cosmos/tree/main/assets/robot_augmentation_example/example1)
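To see how these mapping files are consumed, here is a minimal sketch of the foreground test applied in Step 1; the file name is a placeholder, and the real logic lives in `spatial_temporal_weight.py` below:

```python
import json
import re

with open("segmentation_label/frame_000000.json") as f:  # hypothetical file name
    mapping = json.load(f)

robot_keywords = ["world_robot", "gripper", "robot"]
for color_key, data in mapping.items():
    # "(29, 0, 0, 255)" -> (29, 0, 0); the alpha channel is dropped.
    r, g, b = map(int, re.findall(r"\d+", color_key)[:3])
    is_robot = any(kw in data["class"] for kw in robot_keywords)
    print((r, g, b), data["class"], "FG" if is_robot else "BG")
```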
#### Usage

Note that `--setting` takes the setting names used by the script, not the shorthand "setting1"/"setting2":

```bash
PYTHONPATH=$(pwd) python cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py \
    --setting fg_vis_edge_bg_seg \
    --robot-keywords world_robot gripper robot \
    --input-dir assets/robot_augmentation_example \
    --output-dir outputs/robot_augmentation_example
```

#### Parameters:

* `--setting`: Weight setting to use (choices: `fg_vis_edge_bg_seg`, `fg_edge_bg_seg`; default: `fg_vis_edge_bg_seg`)
  * `fg_vis_edge_bg_seg` (setting 1): emphasizes the robot in the visual and edge features (vis: 1.0 foreground, edge: 1.0 foreground, seg: 1.0 background)
  * `fg_edge_bg_seg` (setting 2): emphasizes the robot only in the edge features (edge: 1.0 foreground, seg: 1.0 background)

* `--input-dir`: Input directory containing example folders
  * Default: `assets/robot_augmentation_example`

* `--output-dir`: Output directory for weight matrices
  * Default: `outputs/robot_augmentation_example`

* `--robot-keywords`: Keywords used to identify robot classes
  * Default: `world_robot gripper robot`
  * Any semantic class containing one of these keywords is treated as robot foreground
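Step 1 writes one float16 tensor of shape `(num_frames, height, width)` per control modality. A quick sanity check on the outputs, assuming the default output directory and the `example1` folder from above:

```python
import torch

w = torch.load("outputs/robot_augmentation_example/example1/seg_weights.pt")
print(w.shape, w.dtype)        # e.g. torch.Size([T, H, W]) torch.float16
print(w.float().unique())      # binary settings yield only the fg/bg values, e.g. tensor([0., 1.])
```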
### Step 2: Run Cosmos-Transfer1 Inference

Use the generated spatial-temporal weight matrices to perform video augmentation with the appropriate controls.

```bash
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}"
export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}"
export NUM_GPU="${NUM_GPU:=1}"

PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 \
    cosmos_transfer1/diffusion/inference/transfer.py \
    --checkpoint_dir $CHECKPOINT_DIR \
    --video_save_folder outputs/robot_example_spatial_temporal_setting1 \
    --controlnet_specs assets/robot_augmentation_example/example1/inference_cosmos_transfer1_robot_spatiotemporal_weights.json \
    --offload_text_encoder_model \
    --offload_guardrail_models \
    --num_gpus $NUM_GPU
```

- Augmented videos are saved in `outputs/robot_example_spatial_temporal_setting1/`

## Input and Output Examples

Input video:

<video src="https://github.com/user-attachments/assets/9c2df99d-7d0c-4dcf-af87-4ec9f65328ed">
Your browser does not support the video tag.
</video>

You can run the pipeline multiple times with different prompts (e.g., `assets/robot_augmentation_example/example1/example1_prompts.json`) to get different augmentation results:

<video src="https://github.com/user-attachments/assets/6dee15f5-9d8b-469a-a92a-3419cb466d44">
Your browser does not support the video tag.
</video>
cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py
ADDED
@@ -0,0 +1,577 @@
```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script processes per-frame segmentation results saved as JSON files and generates
# spatial-temporal weight matrices saved as .pt files. The input JSON files contain the
# segmentation information for each frame; each output .pt file holds one spatial-temporal
# weight matrix for the video.

import argparse
import glob
import json
import logging
import os
import re

import cv2
import numpy as np
import torch
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class WeightSettings:
    """Class to manage different weight settings for the features"""

    @staticmethod
    def get_settings(setting_name):
        """Get weight settings by name

        Args:
            setting_name (str): Name of the setting

        Returns:
            dict: Dictionary with weights for each feature
        """
        settings = {
            # Default setting: emphasize the robot in the vis and edge features
            "fg_vis_edge_bg_seg": {
                "depth": {"foreground": 0.0, "background": 0.0},
                "vis": {"foreground": 1.0, "background": 0.0},
                "edge": {"foreground": 1.0, "background": 0.0},
                "seg": {"foreground": 0.0, "background": 1.0},
            },
            "fg_edge_bg_seg": {
                "depth": {"foreground": 0.0, "background": 0.0},
                "vis": {"foreground": 0.0, "background": 0.0},
                "edge": {"foreground": 1.0, "background": 0.0},
                "seg": {"foreground": 0.0, "background": 1.0},
            },
        }

        if setting_name not in settings:
            logger.warning(f"Setting '{setting_name}' not found. Using default.")
            return settings["fg_vis_edge_bg_seg"]

        return settings[setting_name]

    @staticmethod
    def list_settings():
        """List all available settings

        Returns:
            list: List of setting names
        """
        return ["fg_vis_edge_bg_seg", "fg_edge_bg_seg"]


def get_video_info(video_path):
    """Get video dimensions and frame count"""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    cap.release()
    return width, height, frame_count, fps


def parse_color_key(color_key):
    """Parse a color key string into an RGB tuple

    Args:
        color_key (str): Color key string in the format "(r,g,b,a)" or similar

    Returns:
        tuple: RGB tuple (r, g, b)
    """
    # Extract numbers using regex to handle different formats
    numbers = re.findall(r"\d+", color_key)
    if len(numbers) >= 3:
        r, g, b = map(int, numbers[:3])
        return (r, g, b)
    else:
        raise ValueError(f"Invalid color key format: {color_key}")


def save_visualization(mask, frame_num, feature_name, viz_dir):
    """Save a visualization of the binary mask

    Args:
        mask (numpy.ndarray): The mask (values 0 or 255)
        frame_num (int): The frame number
        feature_name (str): The name of the feature (depth, vis, edge, seg)
        viz_dir (str): Directory to save visualizations
    """
    # Simply save the binary mask directly
    output_path = os.path.join(viz_dir, f"{feature_name}_frame_{frame_num:06d}.png")
    cv2.imwrite(output_path, mask)
    logger.info(f"Saved binary visualization to {output_path}")


def process_segmentation_files(
    segmentation_dir,
    output_dir,
    viz_dir,
    video_path=None,
    weights_dict=None,
    setting_name="fg_vis_edge_bg_seg",
    robot_keywords=None,
):
    """Process all segmentation JSON files and create weight matrices

    Args:
        segmentation_dir (str): Directory containing segmentation JSON files
        output_dir (str): Directory to save weight matrices
        viz_dir (str): Directory to save visualizations
        video_path (str, optional): Path to the video file. Defaults to None.
        weights_dict (dict, optional): Dictionary with weights for each feature.
            Format: {
                'depth': {'foreground': float, 'background': float},
                'vis': {'foreground': float, 'background': float},
                'edge': {'foreground': float, 'background': float},
                'seg': {'foreground': float, 'background': float}
            }
            Values should be in range 0-1. Defaults to None.
        setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg' (setting 1).
        robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to ["robot"].
    """

    # Set default robot keywords if not provided
    if robot_keywords is None:
        robot_keywords = ["robot"]

    # Get all JSON files
    json_files = sorted(glob.glob(os.path.join(segmentation_dir, "*.json")))
    logger.info(f"Found {len(json_files)} JSON files")

    if len(json_files) == 0:
        raise ValueError(f"No JSON files found in {segmentation_dir}")

    # For example directories, check for PNG files
    png_dir = os.path.join(os.path.dirname(segmentation_dir), "segmentation")
    png_files = []
    if os.path.exists(png_dir):
        png_files = sorted(glob.glob(os.path.join(png_dir, "*.png")))
        logger.info(f"Found {len(png_files)} PNG files in segmentation directory")

    # Step 1: Create a unified color-to-class mapping from all JSON files
    logger.info("Creating unified color-to-class mapping...")
    rgb_to_class = {}
    rgb_to_is_robot = {}

    for json_file in tqdm(json_files, desc="Processing JSON files for unified mapping"):
        with open(json_file, "r") as f:
            json_data = json.load(f)

        for color_key, data in json_data.items():
            color = parse_color_key(color_key)
            class_name = data["class"]

            # Store RGB color for matching
            rgb_to_class[color] = class_name
            rgb_to_is_robot[color] = any(keyword in class_name for keyword in robot_keywords)

    # Print statistics about the unified color mapping
    robot_colors = [color for color, is_robot in rgb_to_is_robot.items() if is_robot]
    logger.info(f"Unified mapping: Found {len(robot_colors)} robot colors out of {len(rgb_to_is_robot)} total colors")
    if robot_colors:
        logger.info(f"Robot classes: {[rgb_to_class[color] for color in robot_colors]}")

    # Convert the color mapping to arrays for vectorized operations
    colors = list(rgb_to_is_robot.keys())
    color_array = np.array(colors)
    is_robot_array = np.array([rgb_to_is_robot[color] for color in colors], dtype=bool)

    # If we have PNG files, get dimensions from the first PNG
    if png_files:
        first_png = cv2.imread(png_files[0])
        if first_png is None:
            raise ValueError(f"Could not read PNG file: {png_files[0]}")

        height, width = first_png.shape[:2]
        frame_count = len(png_files)

        # Match frame numbers between JSON and PNG files to ensure correct correspondence
        json_frame_nums = [int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in json_files]
        png_frame_nums = [int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in png_files]

        # Find common frames between JSON and PNG files
        common_frames = sorted(set(json_frame_nums).intersection(set(png_frame_nums)))
        logger.info(f"Found {len(common_frames)} common frames between JSON and PNG files")

        if len(common_frames) == 0:
            raise ValueError("No matching frames found between JSON and PNG files")

        # Create maps to easily look up files by frame number
        json_map = {int(os.path.basename(f).split("_")[-1].split(".")[0]): f for f in json_files}
        png_map = {int(os.path.basename(f).split("_")[-1].split(".")[0]): f for f in png_files}

        # Create new lists with only the matching files
        json_files = [json_map[frame] for frame in common_frames if frame in json_map]
        png_files = [png_map[frame] for frame in common_frames if frame in png_map]
        num_frames = len(json_files)

        logger.info(f"Using PNG dimensions: {width}x{height}, processing {num_frames} frames")
    else:
        # Get video information if no PNG files are available
        try:
            width, height, frame_count, fps = get_video_info(video_path)
            logger.info(f"Video dimensions: {width}x{height}, {frame_count} frames, {fps} fps")
            num_frames = min(len(json_files), frame_count)
        except Exception as e:
            logger.warning(f"Warning: Could not get video information: {e}")
            # Use a default size if we can't get the video info
            width, height = 640, 480
            num_frames = len(json_files)
            logger.info(f"Using default dimensions: {width}x{height}, {num_frames} frames")

    # Initialize weight tensors
    depth_weights = torch.zeros((num_frames, height, width))
    vis_weights = torch.zeros((num_frames, height, width))
    edge_weights = torch.zeros((num_frames, height, width))
    seg_weights = torch.zeros((num_frames, height, width))

    # Process frames
    if png_files:
        # Process PNG files directly
        for i, (json_file, png_file) in enumerate(zip(json_files, png_files)):
            # Get the frame number from the filename
            frame_num = int(os.path.basename(json_file).split("_")[-1].split(".")[0])

            # Read the corresponding PNG file
            frame = cv2.imread(png_file)

            if frame is None:
                logger.warning(f"Warning: Could not read frame {i} from PNG. Using blank frame.")
                frame = np.zeros((height, width, 3), dtype=np.uint8)

            # Convert the frame to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Calculate total pixels
            total_pixels = height * width

            # Vectorized approach for finding nearest colors:
            # convert frame_rgb to a 2D array of shape (height*width, 3)
            pixels = frame_rgb.reshape(-1, 3)

            # Calculate distances between each pixel and each color (vectorized).
            # This creates a matrix of shape (height*width, num_colors).
            distances = np.sqrt(np.sum((pixels[:, np.newaxis, :] - color_array[np.newaxis, :, :]) ** 2, axis=2))

            # Find the index of the nearest color for each pixel
            nearest_color_indices = np.argmin(distances, axis=1)

            # Get the is_robot value for each pixel based on its nearest color
            pixel_is_robot = is_robot_array[nearest_color_indices]

            # Reshape back to image dimensions
            pixel_is_robot_2d = pixel_is_robot.reshape(height, width)

            # Count robot and matched pixels
            robot_pixel_count = np.sum(pixel_is_robot)
            matched_pixel_count = pixels.shape[0]  # All pixels are matched now

            # Create masks based on the is_robot classification
            depth_mask = np.where(
                pixel_is_robot_2d, weights_dict["depth"]["foreground"], weights_dict["depth"]["background"]
            )

            vis_mask = np.where(pixel_is_robot_2d, weights_dict["vis"]["foreground"], weights_dict["vis"]["background"])

            edge_mask = np.where(
                pixel_is_robot_2d, weights_dict["edge"]["foreground"], weights_dict["edge"]["background"]
            )

            seg_mask = np.where(pixel_is_robot_2d, weights_dict["seg"]["foreground"], weights_dict["seg"]["background"])

            # Create the visualization mask
            visualization_mask = np.zeros((height, width), dtype=np.uint8)
            visualization_mask[pixel_is_robot_2d] = 255

            # Log statistics
            robot_percentage = (robot_pixel_count / total_pixels) * 100
            matched_percentage = (matched_pixel_count / total_pixels) * 100
            logger.info(f"Frame {frame_num}: {robot_pixel_count} robot pixels ({robot_percentage:.2f}%)")
            logger.info(f"Frame {frame_num}: {matched_pixel_count} matched pixels ({matched_percentage:.2f}%)")

            # Save visualizations for this frame
            save_visualization(visualization_mask, frame_num, "segmentation", viz_dir)

            # Store the masks in the weight tensors
            depth_weights[i] = torch.from_numpy(depth_mask)
            vis_weights[i] = torch.from_numpy(vis_mask)
            edge_weights[i] = torch.from_numpy(edge_mask)
            seg_weights[i] = torch.from_numpy(seg_mask)
    else:
        # Use video frames if available
        try:
            # Open the segmentation video
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise ValueError(f"Could not open video file: {video_path}")

            # Process each frame using the unified color mapping
            for i, json_file in enumerate(tqdm(json_files[:num_frames], desc="Processing frames")):
                # Get the frame number from the filename
                frame_num = int(os.path.basename(json_file).split("_")[-1].split(".")[0])

                # Read the corresponding frame from the video
                cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                ret, frame = cap.read()

                if not ret:
                    logger.warning(f"Warning: Could not read frame {i} from video. Using blank frame.")
                    frame = np.zeros((height, width, 3), dtype=np.uint8)

                # Convert the frame to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Calculate total pixels
                total_pixels = height * width

                # Vectorized approach for finding nearest colors
                pixels = frame_rgb.reshape(-1, 3)
                distances = np.sqrt(np.sum((pixels[:, np.newaxis, :] - color_array[np.newaxis, :, :]) ** 2, axis=2))
                nearest_color_indices = np.argmin(distances, axis=1)
                pixel_is_robot = is_robot_array[nearest_color_indices]
                pixel_is_robot_2d = pixel_is_robot.reshape(height, width)

                # Count robot and matched pixels
                robot_pixel_count = np.sum(pixel_is_robot)
                matched_pixel_count = pixels.shape[0]

                # Create masks based on the is_robot classification
                depth_mask = np.where(
                    pixel_is_robot_2d, weights_dict["depth"]["foreground"], weights_dict["depth"]["background"]
                )
                vis_mask = np.where(
                    pixel_is_robot_2d, weights_dict["vis"]["foreground"], weights_dict["vis"]["background"]
                )
                edge_mask = np.where(
                    pixel_is_robot_2d, weights_dict["edge"]["foreground"], weights_dict["edge"]["background"]
                )
                seg_mask = np.where(
                    pixel_is_robot_2d, weights_dict["seg"]["foreground"], weights_dict["seg"]["background"]
                )

                # Create the visualization mask
                visualization_mask = np.zeros((height, width), dtype=np.uint8)
                visualization_mask[pixel_is_robot_2d] = 255

                # Log statistics
                robot_percentage = (robot_pixel_count / total_pixels) * 100
                matched_percentage = (matched_pixel_count / total_pixels) * 100
                logger.info(f"Frame {frame_num}: {robot_pixel_count} robot pixels ({robot_percentage:.2f}%)")
                logger.info(f"Frame {frame_num}: {matched_pixel_count} matched pixels ({matched_percentage:.2f}%)")

                # Save visualizations for this frame
                save_visualization(visualization_mask, frame_num, "segmentation", viz_dir)

                # Store the masks in the weight tensors
                depth_weights[i] = torch.from_numpy(depth_mask)
                vis_weights[i] = torch.from_numpy(vis_mask)
                edge_weights[i] = torch.from_numpy(edge_mask)
                seg_weights[i] = torch.from_numpy(seg_mask)

            # Close the video capture
            cap.release()
        except Exception as e:
            logger.warning(f"Warning: Error processing video: {e}")
            logger.warning("Cannot process this example without proper frame data.")
            raise ValueError(f"Cannot process example without frame data: {e}")

    # Save weight tensors in half precision (float16) to reduce file size
    depth_weights_half = depth_weights.to(torch.float16)
    vis_weights_half = vis_weights.to(torch.float16)
    edge_weights_half = edge_weights.to(torch.float16)
    seg_weights_half = seg_weights.to(torch.float16)

    torch.save(depth_weights_half, os.path.join(output_dir, "depth_weights.pt"))
    torch.save(vis_weights_half, os.path.join(output_dir, "vis_weights.pt"))
    torch.save(edge_weights_half, os.path.join(output_dir, "edge_weights.pt"))
    torch.save(seg_weights_half, os.path.join(output_dir, "seg_weights.pt"))

    logger.info(f"Saved weight matrices to {output_dir}")
    logger.info(f"Weight matrix shape: {depth_weights_half.shape}, dtype: {depth_weights_half.dtype}")
    logger.info(f"Saved visualizations to {viz_dir}")

    return output_dir, viz_dir


def process_all_examples(input_dir, output_dir, setting_name="fg_vis_edge_bg_seg", robot_keywords=None):
    """Process all example directories in the provided input directory

    Args:
        input_dir (str): Input directory containing example folders
        output_dir (str): Output directory for weight matrices
        setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg'.
        robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to None.
    """
    # Find all example directories
    if not os.path.exists(input_dir):
        logger.error(f"Input directory not found: {input_dir}")
        return []

    # List example directories
    examples = [d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d))]
    examples = sorted(examples)

    if not examples:
        logger.warning("No example directories found.")
        return []

    # Print the found examples
    logger.info(f"Found {len(examples)} example directories:")
    for example in examples:
        logger.info(f"  - {example}")

    # Store processing results
    results = []

    # Process each example
    for example in examples:
        try:
            logger.info(f"\nProcessing {example}...")

            # Process this example with custom directories
            out_dir, viz_dir = process_example_with_dirs(example, input_dir, output_dir, setting_name, robot_keywords)
            results.append((example, out_dir, viz_dir))

            logger.info(f"Results for {example} saved to:")
            logger.info(f"  Weight matrices: {out_dir}")
            logger.info(f"  Visualizations: {viz_dir}")

        except Exception as e:
            logger.error(f"Error processing {example}: {e}")

    logger.info("\nAll examples processed.")
    return results


def process_example_with_dirs(
    example_name, input_dir, output_dir, setting_name="fg_vis_edge_bg_seg", robot_keywords=None
):
    """Process a specific example with custom input and output directories

    Args:
        example_name (str): Name of the example directory
        input_dir (str): Path to the input directory containing example folders
        output_dir (str): Path to the output directory for weight matrices
        setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg'.
        robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to None.
    """
    # Create paths for this example
    example_dir = os.path.join(input_dir, example_name)
    segmentation_dir = os.path.join(example_dir, "segmentation_label")
    video_path = os.path.join(example_dir, "segmentation.mp4")

    # Create output directories
    example_output_dir = os.path.join(output_dir, example_name)
    viz_dir = os.path.join(example_output_dir, "visualizations")

    # Check if the weight files already exist
    depth_weights_path = os.path.join(example_output_dir, "depth_weights.pt")
    if os.path.exists(depth_weights_path):
        logger.info(f"Weight files already exist for {example_name}, skipping processing")
        return example_output_dir, viz_dir

    # Create the output directories if they don't exist
    os.makedirs(example_output_dir, exist_ok=True)
    os.makedirs(viz_dir, exist_ok=True)

    # Get the weight settings
    weights_dict = WeightSettings.get_settings(setting_name)

    # Process this example directly with paths
    return process_segmentation_files(
        segmentation_dir=segmentation_dir,
        output_dir=example_output_dir,
        viz_dir=viz_dir,
        video_path=video_path,
        weights_dict=weights_dict,
        setting_name=setting_name,
        robot_keywords=robot_keywords,
    )


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description="Process segmentation files to generate spatial-temporal weight matrices"
    )
    parser.add_argument(
        "--setting",
        type=str,
        default="fg_vis_edge_bg_seg",
        choices=WeightSettings.list_settings(),
        help="Weight setting to use: fg_vis_edge_bg_seg (setting 1) or fg_edge_bg_seg (setting 2)",
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        default="assets/robot_augmentation_example",
        help="Input directory containing example folders",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/robot_augmentation_example",
        help="Output directory for weight matrices",
    )
    parser.add_argument(
        "--robot-keywords",
        type=str,
        nargs="+",
        default=["world_robot", "gripper", "robot"],
        help="Keywords used to identify robot classes (default: world_robot gripper robot)",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level",
    )
    args = parser.parse_args()

    # Set the logging level from the command-line argument
    logger.setLevel(getattr(logging, args.log_level))

    # Get directories from the arguments
    input_dir = args.input_dir
    output_dir = args.output_dir
    setting_name = args.setting
    robot_keywords = args.robot_keywords

    logger.info(f"Using input directory: {input_dir}")
    logger.info(f"Using output directory: {output_dir}")
    logger.info(f"Using weight setting: {setting_name}")
    logger.info(f"Using robot keywords: {robot_keywords}")

    # Process all examples with the provided input and output directories
    process_all_examples(input_dir, output_dir, setting_name, robot_keywords)
```
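The heart of `process_segmentation_files` is the vectorized nearest-color lookup: every pixel is labeled with the robot/background flag of its closest palette color, which tolerates compression noise in the segmentation PNGs. A self-contained toy run of that step:

```python
import numpy as np

# Palette of two known colors and their robot/background labels.
color_array = np.array([[29, 0, 0], [200, 200, 200]], dtype=np.float64)
is_robot_array = np.array([True, False])

# A fake 2x2 RGB frame; pixels rarely match the palette exactly.
frame_rgb = np.array([[[30, 1, 0], [198, 201, 199]],
                      [[28, 0, 2], [205, 195, 200]]], dtype=np.float64)

pixels = frame_rgb.reshape(-1, 3)
# (num_pixels, num_colors) distance matrix, then argmin per pixel.
distances = np.sqrt(np.sum((pixels[:, np.newaxis, :] - color_array[np.newaxis, :, :]) ** 2, axis=2))
nearest = np.argmin(distances, axis=1)
print(is_robot_array[nearest].reshape(2, 2))
# [[ True False]
#  [ True False]]
```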
cosmos_transfer1/auxiliary/sam2/sam2_model.py
ADDED
@@ -0,0 +1,392 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import pycocotools.mask as mask_util
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from cosmos_transfer1.utils import log
|
24 |
+
|
25 |
+
sys.path.append("cosmos_transfer1/auxiliary")
|
26 |
+
|
27 |
+
import tempfile
|
28 |
+
|
29 |
+
from PIL import Image
|
30 |
+
from sam2.sam2_video_predictor import SAM2VideoPredictor
|
31 |
+
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
32 |
+
|
33 |
+
from cosmos_transfer1.auxiliary.sam2.sam2_utils import (
|
34 |
+
capture_fps,
|
35 |
+
convert_masks_to_frames,
|
36 |
+
generate_tensor_from_images,
|
37 |
+
video_to_frames,
|
38 |
+
write_video,
|
39 |
+
)
|
40 |
+
from cosmos_transfer1.checkpoints import GROUNDING_DINO_MODEL_CHECKPOINT, SAM2_MODEL_CHECKPOINT
|
41 |
+
|
42 |
+
|
43 |
+
def rle_encode(mask: np.ndarray) -> dict:
|
44 |
+
"""
|
45 |
+
Encode a boolean mask (of shape (T, H, W)) using the pycocotools RLE format,
|
46 |
+
matching the format of eff_segmentation.RleMaskSAMv2 (from Yotta).
|
47 |
+
|
48 |
+
The procedure is:
|
49 |
+
1. Convert the mask to a numpy array in Fortran order.
|
50 |
+
2. Reshape the array to (-1, 1) (i.e. flatten in Fortran order).
|
51 |
+
3. Call pycocotools.mask.encode on the reshaped array.
|
52 |
+
4. Return a dictionary with the encoded data and the original mask shape.
|
53 |
+
"""
|
54 |
+
mask = np.array(mask, order="F")
|
55 |
+
# Reshape the mask to (-1, 1) in Fortran order and encode it.
|
56 |
+
encoded = mask_util.encode(np.array(mask.reshape(-1, 1), order="F"))
|
57 |
+
return {"data": encoded, "mask_shape": mask.shape}
|
58 |
+
|
59 |
+
|
60 |
+
class VideoSegmentationModel:
|
61 |
+
def __init__(self, **kwargs):
|
62 |
+
"""Initialize the model and load all required components."""
|
63 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
64 |
+
|
65 |
+
# Initialize SAM2 predictor
|
66 |
+
self.sam2_predictor = SAM2VideoPredictor.from_pretrained(SAM2_MODEL_CHECKPOINT).to(self.device)
|
67 |
+
|
68 |
+
# Initialize GroundingDINO for text-based detection
|
69 |
+
self.grounding_model_name = kwargs.get("grounding_model", GROUNDING_DINO_MODEL_CHECKPOINT)
|
70 |
+
self.processor = AutoProcessor.from_pretrained(self.grounding_model_name)
|
71 |
+
self.grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(self.grounding_model_name).to(
|
72 |
+
self.device
|
73 |
+
)
|
74 |
+
|
75 |
+
def get_boxes_from_text(self, image_path, text_prompt):
|
76 |
+
"""Get bounding boxes (and labels) from a text prompt using GroundingDINO."""
|
77 |
+
image = Image.open(image_path).convert("RGB")
|
78 |
+
|
79 |
+
inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
|
80 |
+
|
81 |
+
with torch.no_grad():
|
82 |
+
outputs = self.grounding_model(**inputs)
|
83 |
+
|
84 |
+
# Try with initial thresholds.
|
85 |
+
results = self.processor.post_process_grounded_object_detection(
|
86 |
+
outputs,
|
87 |
+
inputs.input_ids,
|
88 |
+
box_threshold=0.15,
|
89 |
+
text_threshold=0.25,
|
90 |
+
target_sizes=[image.size[::-1]],
|
91 |
+
)
|
92 |
+
|
93 |
+
boxes = results[0]["boxes"].cpu().numpy()
|
94 |
+
scores = results[0]["scores"].cpu().numpy()
|
95 |
+
labels = results[0].get("labels", None)
|
96 |
+
if len(boxes) == 0:
|
97 |
+
print(f"No boxes detected for prompt: '{text_prompt}'. Trying with lower thresholds...")
|
98 |
+
results = self.processor.post_process_grounded_object_detection(
|
99 |
+
outputs,
|
100 |
+
inputs.input_ids,
|
101 |
+
box_threshold=0.1,
|
102 |
+
text_threshold=0.1,
|
103 |
+
target_sizes=[image.size[::-1]],
|
104 |
+
)
|
105 |
+
boxes = results[0]["boxes"].cpu().numpy()
|
106 |
+
scores = results[0]["scores"].cpu().numpy()
|
107 |
+
labels = results[0].get("labels", None)
|
108 |
+
|
109 |
+
if len(boxes) > 0:
|
110 |
+
print(f"Found {len(boxes)} boxes with scores: {scores}")
|
111 |
+
# Sort boxes by confidence score in descending order
|
112 |
+
sorted_indices = np.argsort(scores)[::-1]
|
113 |
+
boxes = boxes[sorted_indices]
|
114 |
+
scores = scores[sorted_indices]
|
115 |
+
if labels is not None:
|
116 |
+
labels = np.array(labels)[sorted_indices]
|
117 |
+
else:
|
118 |
+
print("Still no boxes detected. Consider adjusting the prompt or using box/points mode.")
|
119 |
+
|
120 |
+
return {"boxes": boxes, "labels": labels, "scores": scores}
|
121 |
+
|
122 |
+
def visualize_frame(self, frame_idx, obj_ids, masks, video_dir, frame_names, visualization_data, save_dir=None):
|
123 |
+
"""
|
124 |
+
Process a single frame: load the image, apply the segmentation mask to black out the
|
125 |
+
detected object(s), and save both the masked frame and the binary mask image.
|
126 |
+
"""
|
127 |
+
# Load the frame.
|
128 |
+
frame_path = os.path.join(video_dir, frame_names[frame_idx])
|
129 |
+
img = Image.open(frame_path).convert("RGB")
|
130 |
+
image_np = np.array(img)
|
131 |
+
|
132 |
+
# Combine masks from the detection output.
|
133 |
+
if isinstance(masks, torch.Tensor):
|
134 |
+
mask_np = (masks[0] > 0.0).cpu().numpy().astype(bool)
|
135 |
+
combined_mask = mask_np
|
136 |
+
elif isinstance(masks, dict):
|
137 |
+
first_mask = next(iter(masks.values()))
|
138 |
+
combined_mask = np.zeros_like(first_mask, dtype=bool)
|
139 |
+
for m in masks.values():
|
140 |
+
combined_mask |= m
|
141 |
+
else:
|
142 |
+
combined_mask = None
|
143 |
+
|
144 |
+
if combined_mask is not None:
|
145 |
+
combined_mask = np.squeeze(combined_mask)
|
146 |
+
|
147 |
+
# If the mask shape doesn't match the image, resize it.
|
148 |
+
if combined_mask.shape != image_np.shape[:2]:
|
149 |
+
mask_img = Image.fromarray((combined_mask.astype(np.uint8)) * 255)
|
150 |
+
mask_img = mask_img.resize((image_np.shape[1], image_np.shape[0]), resample=Image.NEAREST)
|
151 |
+
combined_mask = np.array(mask_img) > 127
|
152 |
+
|
153 |
+
# Black out the detected region.
|
154 |
+
image_np[combined_mask] = 0
|
155 |
+
|
156 |
+
mask_image = (combined_mask.astype(np.uint8)) * 255
|
157 |
+
mask_pil = Image.fromarray(mask_image)
|
158 |
+
|
159 |
+
if save_dir:
|
160 |
+
seg_frame_path = os.path.join(save_dir, f"frame_{frame_idx}_segmented.png")
|
161 |
+
seg_pil = Image.fromarray(image_np)
|
162 |
+
seg_pil.save(seg_frame_path)
|
163 |
+
if combined_mask is not None:
|
164 |
+
mask_save_path = os.path.join(save_dir, f"frame_{frame_idx}_mask.png")
|
165 |
+
mask_pil.save(mask_save_path)
|
166 |
+
|
167 |
+
def sample(self, **kwargs):
|
168 |
+
"""
|
169 |
+
Main sampling function for video segmentation.
|
170 |
+
Returns a list of detections in which each detection contains a phrase and
|
171 |
+
an RLE-encoded segmentation mask (matching the output of the Grounded SAM model).
|
172 |
+
"""
|
173 |
+
video_dir = kwargs.get("video_dir", "")
|
174 |
+
mode = kwargs.get("mode", "points")
|
175 |
+
input_data = kwargs.get("input_data", None)
|
176 |
+
save_dir = kwargs.get("save_dir", None)
|
177 |
+
visualize = kwargs.get("visualize", False)
|
178 |
+
|
179 |
+
# Get frame names (expecting frames named as numbers with .jpg/.jpeg extension).
|
180 |
+
frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]]
|
181 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
182 |
+
|
183 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
184 |
+
state = self.sam2_predictor.init_state(video_path=video_dir)
|
185 |
+
|
186 |
+
ann_frame_idx = 0
|
187 |
+
ann_obj_id = 1
|
188 |
+
boxes = None
|
189 |
+
points = None
|
190 |
+
labels = None
|
191 |
+
box = None
|
192 |
+
|
193 |
+
visualization_data = {"mode": mode, "points": None, "labels": None, "box": None, "boxes": None}
|
194 |
+
|
195 |
+
if input_data is not None:
|
196 |
+
if mode == "points":
|
197 |
+
points = input_data.get("points")
|
198 |
+
labels = input_data.get("labels")
|
199 |
+
frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
|
200 |
+
inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
|
201 |
+
)
|
202 |
+
visualization_data["points"] = points
|
203 |
+
visualization_data["labels"] = labels
|
204 |
+
elif mode == "box":
|
205 |
+
box = input_data.get("box")
|
206 |
+
frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
|
207 |
+
inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=box
|
208 |
+
)
|
209 |
+
visualization_data["box"] = box
|
210 |
+
elif mode == "prompt":
|
211 |
+
text = input_data.get("text")
|
212 |
+
first_frame_path = os.path.join(video_dir, frame_names[0])
|
213 |
+
gd_results = self.get_boxes_from_text(first_frame_path, text)
|
214 |
+
boxes = gd_results["boxes"]
|
215 |
+
labels_out = gd_results["labels"]
|
216 |
+
scores = gd_results["scores"]
|
217 |
+
log.info(f"scores: {scores}")
|
218 |
+
if len(boxes) > 0:
|
219 |
+
legacy_mask = kwargs.get("legacy_mask", False)
|
220 |
+
if legacy_mask:
|
221 |
+
# Use only the highest confidence box for legacy mask
|
222 |
+
log.info(f"using legacy_mask: {legacy_mask}")
|
223 |
+
frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
|
224 |
+
inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=boxes[0]
|
225 |
+
)
|
226 |
+
# Update boxes and labels after processing
|
227 |
+
boxes = boxes[:1]
|
228 |
+
if labels_out is not None:
|
229 |
+
labels_out = labels_out[:1]
|
230 |
+
else:
|
231 |
+
log.info(f"using new_mask: {legacy_mask}")
|
232 |
+
for object_id, (box, label) in enumerate(zip(boxes, labels_out)):
|
233 |
+
frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
|
234 |
+
inference_state=state, frame_idx=ann_frame_idx, obj_id=object_id, box=box
|
235 |
+
)
|
236 |
+
visualization_data["boxes"] = boxes
|
237 |
+
self.grounding_labels = [str(lbl) for lbl in labels_out] if labels_out is not None else [text]
|
238 |
+
else:
|
239 |
+
print("No boxes detected. Exiting.")
|
240 |
+
return [] # Return empty list if no detections
|
241 |
+
|
242 |
+
if visualize:
|
243 |
+
self.visualize_frame(
|
244 |
+
frame_idx=ann_frame_idx,
|
245 |
+
obj_ids=obj_ids,
|
246 |
+
masks=masks,
|
247 |
+
video_dir=video_dir,
|
248 |
+
frame_names=frame_names,
|
249 |
+
visualization_data=visualization_data,
|
250 |
+
save_dir=save_dir,
|
251 |
+
)
|
252 |
+
|
253 |
+
video_segments = {} # keys: frame index, values: {obj_id: mask}
|
254 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_predictor.propagate_in_video(state):
|
255 |
+
video_segments[out_frame_idx] = {
|
256 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
|
257 |
+
}
|
258 |
+
|
259 |
+
# For propagated frames, visualization_data is not used.
|
260 |
+
if visualize:
|
261 |
+
propagate_visualization_data = {
|
262 |
+
"mode": mode,
|
263 |
+
"points": None,
|
264 |
+
"labels": None,
|
265 |
+
"box": None,
|
266 |
+
"boxes": None,
|
267 |
+
}
|
268 |
+
self.visualize_frame(
|
269 |
+
frame_idx=out_frame_idx,
|
270 |
+
obj_ids=out_obj_ids,
|
271 |
+
masks=video_segments[out_frame_idx],
|
272 |
+
video_dir=video_dir,
|
273 |
+
frame_names=frame_names,
|
274 |
+
visualization_data=propagate_visualization_data,
|
275 |
+
save_dir=save_dir,
|
276 |
+
)
|
277 |
+
|
278 |
+
# --- Post-process video_segments to produce a list of detections ---
|
279 |
+
if len(video_segments) == 0:
|
280 |
+
return []
|
281 |
+
|
282 |
+
first_frame_path = os.path.join(video_dir, frame_names[0])
|
283 |
+
first_frame = np.array(Image.open(first_frame_path).convert("RGB"))
|
284 |
+
original_shape = first_frame.shape[:2] # (height, width)
|
285 |
+
|
286 |
+
object_masks = {} # key: obj_id, value: list of 2D boolean masks
|
287 |
+
        sorted_frame_indices = sorted(video_segments.keys())
        for frame_idx in sorted_frame_indices:
            segments = video_segments[frame_idx]
            for obj_id, mask in segments.items():
                mask = np.squeeze(mask)
                if mask.ndim != 2:
                    print(f"Warning: Unexpected mask shape {mask.shape} for object {obj_id} in frame {frame_idx}.")
                    continue

                if mask.shape != original_shape:
                    mask_img = Image.fromarray(mask.astype(np.uint8) * 255)
                    mask_img = mask_img.resize((original_shape[1], original_shape[0]), resample=Image.NEAREST)
                    mask = np.array(mask_img) > 127

                if obj_id not in object_masks:
                    object_masks[obj_id] = []
                object_masks[obj_id].append(mask)

        detections = []
        for obj_id, mask_list in object_masks.items():
            mask_stack = np.stack(mask_list, axis=0)  # shape: (T, H, W)
            # Use our new rle_encode (which now follows the eff_segmentation.RleMaskSAMv2 format)
            rle = rle_encode(mask_stack)
            if mode == "prompt" and hasattr(self, "grounding_labels"):
                phrase = self.grounding_labels[0]
            else:
                phrase = input_data.get("text", "")
            detection = {"phrase": phrase, "segmentation_mask_rle": rle}
            detections.append(detection)

        return detections

    @staticmethod
    def parse_points(points_str):
        """Parse a string of points into a numpy array.

        Supports a single point ('200,300') or multiple points separated by ';' (e.g., '200,300;100,150').
        """
        points = []
        for point in points_str.split(";"):
            coords = point.split(",")
            if len(coords) != 2:
                continue
            points.append([float(coords[0]), float(coords[1])])
        return np.array(points, dtype=np.float32)

    @staticmethod
    def parse_labels(labels_str):
        """Parse a comma-separated string of labels into a numpy array."""
        return np.array([int(x) for x in labels_str.split(",")], dtype=np.int32)

    @staticmethod
    def parse_box(box_str):
        """Parse a comma-separated string of 4 box coordinates into a numpy array."""
        return np.array([float(x) for x in box_str.split(",")], dtype=np.float32)

    def __call__(
        self,
        input_video,
        output_video=None,
        output_tensor=None,
        prompt=None,
        box=None,
        points=None,
        labels=None,
        weight_scaler=None,
        binarize_video=False,
        legacy_mask=False,
    ):
        log.info(
            f"Processing video: {input_video} to generate segmentation video: {output_video} segmentation tensor: {output_tensor}"
        )
        assert os.path.exists(input_video)

        # Prepare input data based on the selected mode.
        if points is not None:
            mode = "points"
            input_data = {"points": self.parse_points(points), "labels": self.parse_labels(labels)}
        elif box is not None:
            mode = "box"
            input_data = {"box": self.parse_box(box)}
        elif prompt is not None:
            mode = "prompt"
            input_data = {"text": prompt}
        else:
            # Guard against falling through with `mode` undefined.
            raise ValueError("One of `points`, `box`, or `prompt` must be provided.")

        with tempfile.TemporaryDirectory() as temp_input_dir:
            fps = capture_fps(input_video)
            video_to_frames(input_video, temp_input_dir)
            with tempfile.TemporaryDirectory() as temp_output_dir:
                masks = self.sample(
                    video_dir=temp_input_dir,
                    mode=mode,
                    input_data=input_data,
                    save_dir=str(temp_output_dir),
                    visualize=True,
                    legacy_mask=legacy_mask,
                )
                if output_video:
                    os.makedirs(os.path.dirname(output_video), exist_ok=True)
                    frames = convert_masks_to_frames(masks)
                    if binarize_video:
                        frames = np.any(frames > 0, axis=-1).astype(np.uint8) * 255
                    write_video(frames, output_video, fps)
                if output_tensor:
                    generate_tensor_from_images(
                        temp_output_dir, output_tensor, fps, "mask", weight_scaler=weight_scaler
                    )
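
A minimal usage sketch for the `__call__` entry point above, assuming `VideoSegmentationModel` can be constructed with its defaults; the paths and prompt are placeholders, not assets shipped with the repo:

from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel

model = VideoSegmentationModel()  # assumes default construction works in your setup
model(
    input_video="assets/example.mp4",         # placeholder; must exist on disk
    output_video="outputs/segmentation.mp4",  # colorized mask video
    output_tensor="outputs/segmentation.pt",  # [T, H, W] mask-weight tensor
    prompt="a red car",                       # selects "prompt" mode
)
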
cosmos_transfer1/auxiliary/sam2/sam2_pipeline.py
ADDED
@@ -0,0 +1,126 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import tempfile

import numpy as np

from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel
from cosmos_transfer1.auxiliary.sam2.sam2_utils import (
    capture_fps,
    generate_tensor_from_images,
    generate_video_from_images,
    video_to_frames,
)


def parse_args():
    parser = argparse.ArgumentParser(description="Video Segmentation using SAM2")
    parser.add_argument("--input_video", type=str, required=True, help="Path to input video file")
    parser.add_argument(
        "--output_video", type=str, default="./outputs/output_video.mp4", help="Path to save the output video"
    )
    parser.add_argument(
        "--output_tensor", type=str, default="./outputs/output_tensor.pt", help="Path to save the output tensor"
    )
    parser.add_argument(
        "--mode", type=str, choices=["points", "box", "prompt"], default="points", help="Segmentation mode"
    )
    parser.add_argument("--prompt", type=str, help="Text prompt for prompt mode")
    parser.add_argument(
        "--grounding_model_path",
        type=str,
        default="IDEA-Research/grounding-dino-tiny",
        help="Local directory for GroundingDINO model files",
    )
    parser.add_argument(
        "--points",
        type=str,
        default="200,300",
        help="Comma-separated point coordinates for points mode (e.g., '200,300'; for multiple points use ';' as a separator, e.g., '200,300;100,150').",
    )
    parser.add_argument(
        "--labels",
        type=str,
        default="1",
        help="Comma-separated labels for points mode (e.g., '1' or '1,0' for multiple points).",
    )
    parser.add_argument(
        "--box",
        type=str,
        default="300,0,500,400",
        help="Comma-separated box coordinates for box mode (e.g., '300,0,500,400').",
    )
    # New flag to control visualization.
    parser.add_argument("--visualize", action="store_true", help="If set, visualize segmentation frames (save images)")
    return parser.parse_args()


def parse_points(points_str):
    """Parse a string of points into a numpy array.

    Supports a single point ('200,300') or multiple points separated by ';' (e.g., '200,300;100,150').
    """
    points = []
    for point in points_str.split(";"):
        coords = point.split(",")
        if len(coords) != 2:
            continue
        points.append([float(coords[0]), float(coords[1])])
    return np.array(points, dtype=np.float32)


def parse_labels(labels_str):
    """Parse a comma-separated string of labels into a numpy array."""
    return np.array([int(x) for x in labels_str.split(",")], dtype=np.int32)


def parse_box(box_str):
    """Parse a comma-separated string of 4 box coordinates into a numpy array."""
    return np.array([float(x) for x in box_str.split(",")], dtype=np.float32)


def main():
    args = parse_args()

    # Initialize the segmentation model.
    model = VideoSegmentationModel(**vars(args))

    # Prepare input data based on the selected mode.
    if args.mode == "points":
        input_data = {"points": parse_points(args.points), "labels": parse_labels(args.labels)}
    elif args.mode == "box":
        input_data = {"box": parse_box(args.box)}
    elif args.mode == "prompt":
        input_data = {"text": args.prompt}

    with tempfile.TemporaryDirectory() as temp_input_dir:
        fps = capture_fps(args.input_video)
        video_to_frames(args.input_video, temp_input_dir)
        with tempfile.TemporaryDirectory() as temp_output_dir:
            model.sample(
                video_dir=temp_input_dir,
                mode=args.mode,
                input_data=input_data,
                save_dir=str(temp_output_dir),
                visualize=True,
            )
            generate_video_from_images(temp_output_dir, args.output_video, fps)
            generate_tensor_from_images(temp_output_dir, args.output_tensor, fps, "mask")


if __name__ == "__main__":
    print("Starting video segmentation...")
    main()
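
The CLI above can be invoked as a module; a sketch with placeholder paths, mirroring the defaults defined in parse_args():

python3 -m cosmos_transfer1.auxiliary.sam2.sam2_pipeline \
    --input_video path/to/input.mp4 \
    --mode prompt \
    --prompt "a red car" \
    --output_video ./outputs/output_video.mp4 \
    --output_tensor ./outputs/output_tensor.pt
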
cosmos_transfer1/auxiliary/sam2/sam2_utils.py
ADDED
@@ -0,0 +1,168 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import time

import cv2
import imageio
import numpy as np
import pycocotools.mask
import torch
from natsort import natsorted
from PIL import Image
from torchvision import transforms

from cosmos_transfer1.diffusion.datasets.augmentors.control_input import (
    decode_partial_rle_width1,
    segmentation_color_mask,
)
from cosmos_transfer1.utils import log


def write_video(frames, output_path, fps=30):
    """
    Expects a sequence of [H, W, 3] or [H, W] frames.
    """
    with imageio.get_writer(output_path, fps=fps, macro_block_size=8) as writer:
        for frame in frames:
            if len(frame.shape) == 2:  # single channel
                frame = frame[:, :, None].repeat(3, axis=2)
            writer.append_data(frame)


def capture_fps(input_video_path: str):
    cap = cv2.VideoCapture(input_video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps


def video_to_frames(input_loc, output_loc):
    """Function to extract frames from input video file
    and save them as separate frames in an output directory.
    Args:
        input_loc: Input video file.
        output_loc: Output directory to save the frames.
    Returns:
        None
    """
    try:
        os.mkdir(output_loc)
    except OSError:
        pass
    # Log the time
    time_start = time.time()
    # Start capturing the feed
    cap = cv2.VideoCapture(input_loc)
    # Find the number of frames
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"Number of frames: {video_length}")
    count = 0
    print("Converting video...\n")
    # Start converting the video
    while cap.isOpened():
        # Extract the frame
        ret, frame = cap.read()
        if not ret:
            # Stop instead of spinning forever if the stream ends early.
            break
        # Write the results back to output location.
        cv2.imwrite(output_loc + "/%#05d.jpg" % (count + 1), frame)
        count = count + 1
        # If there are no more frames left
        if count > (video_length - 1):
            # Log the time again
            time_end = time.time()
            # Release the feed
            cap.release()
            # Print stats
            print("Done extracting frames.\n%d frames extracted" % count)
            print("It took %d seconds for conversion." % (time_end - time_start))
            break


# Convert RLE-encoded masks into colorized frames (used to generate the output video).
def convert_masks_to_frames(masks: list, num_masks_max: int = 100):
    T, H, W = shape = masks[0]["segmentation_mask_rle"]["mask_shape"]
    frame_start, frame_end = 0, T
    num_masks = min(num_masks_max, len(masks))
    mask_ids_select = np.arange(num_masks).tolist()

    all_masks = np.zeros((num_masks, T, H, W), dtype=np.uint8)
    for idx, mid in enumerate(mask_ids_select):
        mask = masks[mid]
        num_byte_per_mb = 1024 * 1024
        # total number of elements in uint8 (1 byte) / num_byte_per_mb
        if shape[0] * shape[1] * shape[2] / num_byte_per_mb > 256:
            rle = decode_partial_rle_width1(
                mask["segmentation_mask_rle"]["data"],
                frame_start * shape[1] * shape[2],
                frame_end * shape[1] * shape[2],
            )
            partial_shape = (frame_end - frame_start, shape[1], shape[2])
            rle = rle.reshape(partial_shape) * 255
        else:
            rle = pycocotools.mask.decode(mask["segmentation_mask_rle"]["data"])
            rle = rle.reshape(shape) * 255
        # Select the frames that are in the video
        frame_indices = np.arange(frame_start, frame_end).tolist()
        rle = np.stack([rle[i] for i in frame_indices])
        all_masks[idx] = rle
        del rle

    all_masks = segmentation_color_mask(all_masks)  # NTHW -> 3THW
    all_masks = all_masks.transpose(1, 2, 3, 0)
    return all_masks


def generate_video_from_images(masks: list, output_file_path: str, fps, num_masks_max: int = 100):
    all_masks = convert_masks_to_frames(masks, num_masks_max)
    write_video(all_masks, output_file_path, fps)
    print("Video generated successfully!")


def generate_tensor_from_images(
    image_path_str: str, output_file_path: str, fps, search_pattern: str = None, weight_scaler: float = None
):
    images = list()
    image_path = os.path.abspath(image_path_str)
    if search_pattern is None:
        images = [img for img in natsorted(os.listdir(image_path))]
    else:
        for img in natsorted(os.listdir(image_path)):
            if img.__contains__(search_pattern):
                images.append(img)

    transform = transforms.ToTensor()
    image_tensors = list()
    for image in images:
        img_tensor = transform(Image.open(os.path.join(image_path, image)))
        image_tensors.append(img_tensor.squeeze(0))

    tensor = torch.stack(image_tensors)  # [T, H, W], binary values, float

    if weight_scaler is not None:
        log.info(f"scaling the tensor by the specified scale: {weight_scaler}")
        tensor = tensor * weight_scaler

    log.info(f"saving tensor shape: {tensor.shape} to {output_file_path}")
    torch.save(tensor, output_file_path)


if __name__ == "__main__":
    input_loc = "cosmos_transfer1/models/sam2/assets/input_video.mp4"
    # Use mkdtemp so the directory persists after this call returns.
    output_loc = tempfile.mkdtemp()
    print(f"output_loc --- {output_loc}")
    video_to_frames(input_loc, output_loc)
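
A short sketch chaining the helpers above, assuming an input video and a directory of saved mask frames whose filenames contain "mask" (all paths and the scaler value are placeholders):

from cosmos_transfer1.auxiliary.sam2.sam2_utils import capture_fps, generate_tensor_from_images, video_to_frames

fps = capture_fps("path/to/input.mp4")          # probe the source frame rate
video_to_frames("path/to/input.mp4", "frames")  # dump frames as 00001.jpg, 00002.jpg, ...
generate_tensor_from_images("masks", "weights.pt", fps, search_pattern="mask", weight_scaler=0.5)
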
cosmos_transfer1/auxiliary/tokenizer/inference/__init__.py
ADDED
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cosmos_transfer1/auxiliary/tokenizer/inference/image_cli.py
ADDED
@@ -0,0 +1,188 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A CLI to run ImageTokenizer on plain images based on torch.jit.

Usage:
    python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.image_cli \
        --image_pattern 'path/to/input/folder/*.jpg' \
        --output_dir ./reconstructions \
        --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
        --checkpoint_dec ./checkpoints/<model-name>/decoder.jit

    Optionally, you can run the model in pure PyTorch mode:
    python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.image_cli \
        --image_pattern 'path/to/input/folder/*.jpg' \
        --mode torch \
        --tokenizer_type CI \
        --spatial_compression 8 \
        --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
        --checkpoint_dec ./checkpoints/<model-name>/decoder.jit
"""

import os
import sys
from argparse import ArgumentParser, Namespace

import numpy as np
from loguru import logger as logging

from cosmos_transfer1.auxiliary.tokenizer.inference.image_lib import ImageTokenizer
from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
    get_filepaths,
    get_output_filepath,
    read_image,
    resize_image,
    write_image,
)
from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs


def _parse_args() -> Namespace:
    parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.")
    parser.add_argument(
        "--image_pattern",
        type=str,
        default="path/to/images/*.jpg",
        help="Glob pattern.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="JIT full Autoencoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_enc",
        type=str,
        default=None,
        help="JIT Encoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_dec",
        type=str,
        default=None,
        help="JIT Decoder model filepath.",
    )
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["CI", "DI"],
        help="Specifies the tokenizer type.",
    )
    parser.add_argument(
        "--spatial_compression",
        type=int,
        choices=[8, 16],
        default=8,
        help="The spatial compression factor.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["torch", "jit"],
        default="jit",
        help="Specify the backend: native 'torch' or 'jit' (default: 'jit')",
    )
    parser.add_argument(
        "--short_size",
        type=int,
        default=None,
        help="The size to resample inputs. None, by default.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Sets the precision. Default bfloat16.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for invoking the model.",
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Output directory.")
    parser.add_argument(
        "--save_input",
        action="store_true",
        help="If set, the input image will be written out too.",
    )
    args = parser.parse_args()
    return args


logging.info("Initializing args ...")
args = _parse_args()
if args.mode == "torch" and args.tokenizer_type not in ["CI", "DI"]:
    logging.error("'torch' backend requires the tokenizer_type of 'CI' or 'DI'.")
    sys.exit(1)


def _run_eval() -> None:
    """Invokes the evaluation pipeline."""

    if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None:
        logging.warning("Aborting. Provide either both encoder and decoder JIT checkpoints, or the full autoencoder JIT model.")
        return

    if args.mode == "torch":
        tokenizer_config = TokenizerConfigs[args.tokenizer_type].value
        tokenizer_config.update(dict(spatial_compression=args.spatial_compression))
    else:
        tokenizer_config = None

    logging.info(
        f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
    )
    autoencoder = ImageTokenizer(
        checkpoint=args.checkpoint,
        checkpoint_enc=args.checkpoint_enc,
        checkpoint_dec=args.checkpoint_dec,
        tokenizer_config=tokenizer_config,
        device=args.device,
        dtype=args.dtype,
    )

    filepaths = get_filepaths(args.image_pattern)
    logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.")

    for filepath in filepaths:
        logging.info(f"Reading image {filepath} ...")
        image = read_image(filepath)
        image = resize_image(image, short_size=args.short_size)
        batch_image = np.expand_dims(image, axis=0)

        logging.info("Invoking the autoencoder model ...")
        output_image = autoencoder(batch_image)[0]

        output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
        logging.info(f"Outputting {output_filepath} ...")
        write_image(output_filepath, output_image)

        if args.save_input:
            ext = os.path.splitext(output_filepath)[-1]
            input_filepath = output_filepath.replace(ext, "_input" + ext)
            write_image(input_filepath, image)


@logging.catch(reraise=True)
def main() -> None:
    _run_eval()


if __name__ == "__main__":
    main()
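
For the 'torch' backend, the config lookup in _run_eval() composes a plain dict before model construction; a sketch using the 'CI' entry (the compression value is illustrative and must match the checkpoint):

from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs

tokenizer_config = TokenizerConfigs["CI"].value       # continuous-image tokenizer config
tokenizer_config.update(dict(spatial_compression=8))  # must match the checkpoint's compression
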
cosmos_transfer1/auxiliary/tokenizer/inference/image_lib.py
ADDED
@@ -0,0 +1,124 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A library for image tokenizers inference."""

from typing import Any

import numpy as np
import torch

from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
    load_decoder_model,
    load_encoder_model,
    load_model,
    numpy2tensor,
    pad_image_batch,
    tensor2numpy,
    unpad_image_batch,
)


class ImageTokenizer(torch.nn.Module):
    def __init__(
        self,
        checkpoint: str = None,
        checkpoint_enc: str = None,
        checkpoint_dec: str = None,
        tokenizer_config: dict[str, Any] = None,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ) -> None:
        super().__init__()
        self._device = device
        self._dtype = getattr(torch, dtype)
        self._full_model = (
            load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
        )
        self._enc_model = (
            load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
            if checkpoint_enc is not None
            else None
        )
        self._dec_model = (
            load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
            if checkpoint_dec is not None
            else None
        )

    @torch.no_grad()
    def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Reconstructs a batch of image tensors after embedding into a latent.

        Args:
            input_tensor: The input image Bx3xHxW layout, range [-1..1].
        Returns:
            The reconstructed tensor, layout Bx3xHxW, range [-1..1].
        """
        if self._full_model is not None:
            output_tensor = self._full_model(input_tensor)
            output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
        else:
            output_latent = self.encode(input_tensor)[0]
            output_tensor = self.decode(output_latent)
        return output_tensor

    @torch.no_grad()
    def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
        """Decodes an image from a provided latent embedding.

        Args:
            input_latent: The continuous latent Bx16xhxw for CI,
                or the discrete indices Bxhxw for DI.
        Returns:
            The output tensor in Bx3xHxW, range [-1..1].
        """
        return self._dec_model(input_latent)

    @torch.no_grad()
    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
        """Encodes an image into a latent embedding or code.

        Args:
            input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
        Returns:
            For continuous image (CI) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(h)x(w), where the compression
                  rate is (H/h x W/w), and channel dimension of 16.
            For discrete image (DI) tokenizer, the tuple contains:
                - The indices, Bx(h)x(w), from a codebook of size 64K, which
                  corresponds to FSQ levels of (8,8,8,5,5,5).
                - The discrete code, Bx6x(h)x(w), where the compression rate is
                  again (H/h x W/w), and channel dimension of 6.
        """
        output_latent = self._enc_model(input_tensor)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        return output_latent[:-1]

    @torch.no_grad()
    def forward(self, image: np.ndarray) -> np.ndarray:
        """Reconstructs an image using a pre-trained tokenizer.

        Args:
            image: The input image BxHxWxC layout, range [0..255].
        Returns:
            The reconstructed image in range [0..255], layout BxHxWxC.
        """
        padded_input_image, crop_region = pad_image_batch(image)
        input_tensor = numpy2tensor(padded_input_image, dtype=self._dtype, device=self._device)
        output_tensor = self.autoencode(input_tensor)
        padded_output_image = tensor2numpy(output_tensor)
        return unpad_image_batch(padded_output_image, crop_region)
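
A minimal sketch of driving ImageTokenizer directly from numpy, assuming a JIT encoder/decoder pair on disk (the checkpoint paths are placeholders):

import numpy as np

from cosmos_transfer1.auxiliary.tokenizer.inference.image_lib import ImageTokenizer

tokenizer = ImageTokenizer(
    checkpoint_enc="checkpoints/encoder.jit",  # placeholder path
    checkpoint_dec="checkpoints/decoder.jit",  # placeholder path
)
batch = np.random.randint(0, 256, size=(1, 512, 512, 3), dtype=np.uint8)  # BxHxWxC, [0..255]
reconstruction = tokenizer(batch)  # same layout and range as the input
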
cosmos_transfer1/auxiliary/tokenizer/inference/utils.py
ADDED
@@ -0,0 +1,402 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for the inference libraries."""

import os
from glob import glob
from typing import Any

import mediapy as media
import numpy as np
import torch

from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerModels

_DTYPE, _DEVICE = torch.bfloat16, "cuda"
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
_SPATIAL_ALIGN = 16
_TEMPORAL_ALIGN = 8


def load_model(
    jit_filepath: str = None,
    tokenizer_config: dict[str, Any] = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads a full autoencoder model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.
    Returns:
        The model loaded to device and in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    full_model.load_state_dict(ckpts.state_dict(), strict=False)
    return full_model.eval().to(device)


def load_encoder_model(
    jit_filepath: str = None,
    tokenizer_config: dict[str, Any] = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads an encoder model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.
    Returns:
        The encoder loaded to device and in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    encoder_model = full_model.encoder_jit()
    encoder_model.load_state_dict(ckpts.state_dict(), strict=False)
    return encoder_model.eval().to(device)


def load_decoder_model(
    jit_filepath: str = None,
    tokenizer_config: dict[str, Any] = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads a decoder model from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.
    Returns:
        The decoder loaded to device and in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    decoder_model = full_model.decoder_jit()
    decoder_model.load_state_dict(ckpts.state_dict(), strict=False)
    return decoder_model.eval().to(device)


def _load_pytorch_model(
    jit_filepath: str = None, tokenizer_config: dict[str, Any] = None, device: str = "cuda"
) -> tuple[torch.nn.Module, torch.jit.ScriptModule]:
    """Builds a native PyTorch model and loads the JIT checkpoint holding its weights.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.
    Returns:
        The (model, checkpoint) pair.
    """
    tokenizer_name = tokenizer_config["name"]
    model = TokenizerModels[tokenizer_name].value(**tokenizer_config)
    ckpts = torch.jit.load(jit_filepath, map_location=device)
    return model, ckpts


def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.
    Returns:
        The JIT compiled model loaded to device and in eval mode.
    """
    model = torch.jit.load(jit_filepath, map_location=device)
    return model.eval().to(device)


def save_jit_model(
    model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None,
    jit_filepath: str = None,
) -> None:
    """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file.

    Args:
        model: JIT compiled model loaded onto `config.checkpoint.jit.device`.
        jit_filepath: The filepath to the JIT-compiled model.
    """
    torch.jit.save(model, jit_filepath)


def get_filepaths(input_pattern) -> list[str]:
    """Returns a sorted list of unique filepaths matching a glob pattern."""
    filepaths = glob(str(input_pattern))
    # Deduplicate first, then sort, so the returned order is deterministic.
    return sorted(set(filepaths))


def get_output_filepath(filepath: str, output_dir: str = None) -> str:
    """Returns the output filepath for the given input filepath."""
    output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions"
    output_filepath = f"{output_dir}/{os.path.basename(filepath)}"
    os.makedirs(output_dir, exist_ok=True)
    return output_filepath


def read_image(filepath: str) -> np.ndarray:
    """Reads an image from a filepath.

    Args:
        filepath: The filepath to the image.

    Returns:
        The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype.
    """
    image = media.read_image(filepath)
    # convert the grey scale image to RGB
    # since our tokenizers always assume 3-channel RGB image
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    # convert RGBA to RGB
    if image.shape[-1] == 4:
        image = image[..., :3]
    return image


def read_video(filepath: str) -> np.ndarray:
    """Reads a video from a filepath.

    Args:
        filepath: The filepath to the video.
    Returns:
        The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype.
    """
    video = media.read_video(filepath)
    # convert the grey scale frame to RGB
    # since our tokenizers always assume 3-channel video
    if video.ndim == 3:
        video = np.stack([video] * 3, axis=-1)
    # convert RGBA to RGB
    if video.shape[-1] == 4:
        video = video[..., :3]
    return video


def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray:
    """Resizes an image to have the short side of `short_size`.

    Args:
        image: The image to resize, layout HxWxC, of any range.
        short_size: The size of the short side.
    Returns:
        The resized image.
    """
    if short_size is None:
        return image
    height, width = image.shape[-3:-1]
    if height <= width:
        height_new, width_new = short_size, int(width * short_size / height + 0.5)
        width_new = width_new if width_new % 2 == 0 else width_new + 1
    else:
        height_new, width_new = (
            int(height * short_size / width + 0.5),
            short_size,
        )
        height_new = height_new if height_new % 2 == 0 else height_new + 1
    return media.resize_image(image, shape=(height_new, width_new))


def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray:
    """Resizes a video to have the short side of `short_size`.

    Args:
        video: The video to resize, layout TxHxWxC, of any range.
        short_size: The size of the short side.
    Returns:
        The resized video.
    """
    if short_size is None:
        return video
    height, width = video.shape[-3:-1]
    if height <= width:
        height_new, width_new = short_size, int(width * short_size / height + 0.5)
        width_new = width_new if width_new % 2 == 0 else width_new + 1
    else:
        height_new, width_new = (
            int(height * short_size / width + 0.5),
            short_size,
        )
        height_new = height_new if height_new % 2 == 0 else height_new + 1
    return media.resize_video(video, shape=(height_new, width_new))


def write_image(filepath: str, image: np.ndarray):
    """Writes an image to a filepath."""
    return media.write_image(filepath, image)


def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None:
    """Writes a video to a filepath."""
    return media.write_video(filepath, video, fps=fps)


def numpy2tensor(
    input_image: np.ndarray,
    dtype: torch.dtype = _DTYPE,
    device: str = _DEVICE,
    range_min: int = -1,
) -> torch.Tensor:
    """Converts a uint8 image batch in range [0..255] to a `dtype` tensor in range [-1..1].

    Args:
        input_image: A batch of images in range [0..255], BxHxWx3 layout.
    Returns:
        A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype.
    """
    ndim = input_image.ndim
    indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1]
    image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F
    if range_min == -1:
        image = 2.0 * image - 1.0
    return torch.from_numpy(image).to(dtype).to(device)


def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255].

    Args:
        input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1].
    Returns:
        A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
    if range_min == -1:
        input_tensor = (input_tensor.float() + 1.0) / 2.0
    ndim = input_tensor.ndim
    output_image = input_tensor.clamp(0, 1).cpu().numpy()
    output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)


def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]:
    """Pads a batch of images to be divisible by `spatial_align`.

    Args:
        batch: The batch of images to pad, layout BxHxWx3, in any range.
        align: The alignment to pad to.
    Returns:
        The padded batch and the crop region.
    """
    height, width = batch.shape[1:3]
    align = spatial_align
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0

    crop_region = [
        height_to_pad >> 1,
        width_to_pad >> 1,
        height + (height_to_pad >> 1),
        width + (width_to_pad >> 1),
    ]
    batch = np.pad(
        batch,
        (
            (0, 0),
            (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
            (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)),
            (0, 0),
        ),
        mode="constant",
    )
    return batch, crop_region


def pad_video_batch(
    batch: np.ndarray,
    temporal_align: int = _TEMPORAL_ALIGN,
    spatial_align: int = _SPATIAL_ALIGN,
) -> tuple[np.ndarray, list[int]]:
    """Pads a batch of videos to be divisible by `temporal_align` or `spatial_align`.

    Zero-pad spatially. Edge-pad (replicate) temporally to handle causality better.
    Args:
        batch: The batch of videos to pad, layout BxFxHxWx3, in any range.
        align: The alignment to pad to.
    Returns:
        The padded batch and the crop region.
    """
    num_frames, height, width = batch.shape[-4:-1]
    align = spatial_align
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0

    align = temporal_align
    frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0

    crop_region = [
        frames_to_pad >> 1,
        height_to_pad >> 1,
        width_to_pad >> 1,
        num_frames + (frames_to_pad >> 1),
        height + (height_to_pad >> 1),
        width + (width_to_pad >> 1),
    ]
    batch = np.pad(
        batch,
        (
            (0, 0),
            (0, 0),
            (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
            (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)),
            (0, 0),
        ),
        mode="constant",
    )
    batch = np.pad(
        batch,
        (
            (0, 0),
            (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)),
            (0, 0),
            (0, 0),
            (0, 0),
        ),
        mode="edge",
    )
    return batch, crop_region


def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Unpads video with `crop_region`.

    Args:
        batch: A batch of numpy videos, layout BxFxHxWxC.
        crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices.

    Returns:
        np.ndarray: Cropped numpy video, layout BxFxHxWxC.
    """
    assert len(crop_region) == 6, "crop_region should be len of 6."
    f1, y1, x1, f2, y2, x2 = crop_region
    return batch[..., f1:f2, y1:y2, x1:x2, :]


def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Unpads image with `crop_region`.

    Args:
        batch: A batch of numpy images, layout BxHxWxC.
        crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices.

    Returns:
        np.ndarray: Cropped numpy image, layout BxHxWxC.
    """
    assert len(crop_region) == 4, "crop_region should be len of 4."
    y1, x1, y2, x2 = crop_region
    return batch[..., y1:y2, x1:x2, :]
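
A quick round-trip check with the padding helpers above; the shape is illustrative and deliberately not 16-aligned:

import numpy as np

from cosmos_transfer1.auxiliary.tokenizer.inference.utils import pad_image_batch, unpad_image_batch

batch = np.random.randint(0, 256, size=(1, 270, 482, 3), dtype=np.uint8)
padded, crop_region = pad_image_batch(batch)       # H and W padded up to multiples of 16
assert padded.shape[1] % 16 == 0 and padded.shape[2] % 16 == 0
restored = unpad_image_batch(padded, crop_region)  # crops back to the original size
assert restored.shape == batch.shape
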
cosmos_transfer1/auxiliary/tokenizer/inference/video_cli.py
ADDED
@@ -0,0 +1,210 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A CLI to run CausalVideoTokenizer on plain videos based on torch.jit.

Usage:
    python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.video_cli \
        --video_pattern 'path/to/video/samples/*.mp4' \
        --output_dir ./reconstructions \
        --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
        --checkpoint_dec ./checkpoints/<model-name>/decoder.jit

    Optionally, you can run the model in pure PyTorch mode:
    python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.video_cli \
        --video_pattern 'path/to/video/samples/*.mp4' \
        --mode=torch \
        --tokenizer_type=CV \
        --temporal_compression=4 \
        --spatial_compression=8 \
        --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
        --checkpoint_dec ./checkpoints/<model-name>/decoder.jit
"""

import os
import sys
from argparse import ArgumentParser, Namespace

import numpy as np
from loguru import logger as logging

from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
    get_filepaths,
    get_output_filepath,
    read_video,
    resize_video,
    write_video,
)
from cosmos_transfer1.auxiliary.tokenizer.inference.video_lib import CausalVideoTokenizer
from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs


def _parse_args() -> Namespace:
    parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.")
    parser.add_argument(
        "--video_pattern",
        type=str,
        default="path/to/videos/*.mp4",
        help="Glob pattern.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="JIT full Autoencoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_enc",
        type=str,
        default=None,
        help="JIT Encoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_dec",
        type=str,
        default=None,
        help="JIT Decoder model filepath.",
    )
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["CV", "DV"],
        help="Specifies the tokenizer type.",
    )
    parser.add_argument(
        "--spatial_compression",
        type=int,
        choices=[8, 16],
        default=8,
        help="The spatial compression factor.",
    )
    parser.add_argument(
        "--temporal_compression",
        type=int,
        choices=[4, 8],
        default=4,
        help="The temporal compression factor.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["torch", "jit"],
        default="jit",
        help="Specify the backend: native 'torch' or 'jit' (default: 'jit')",
    )
    parser.add_argument(
        "--short_size",
        type=int,
        default=None,
        help="The size to resample inputs. None, by default.",
    )
    parser.add_argument(
        "--temporal_window",
        type=int,
        default=17,
        help="The temporal window to operate at a time.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Sets the precision, default bfloat16.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for invoking the model.",
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Output directory.")
    parser.add_argument(
        "--output_fps",
        type=float,
        default=24.0,
        help="Output frames-per-second (FPS).",
    )
    parser.add_argument(
        "--save_input",
        action="store_true",
        help="If set, the input video will be written out too.",
    )

    args = parser.parse_args()
    return args


logging.info("Initializing args ...")
args = _parse_args()
if args.mode == "torch" and args.tokenizer_type not in ["CV", "DV"]:
    logging.error("'torch' backend requires the tokenizer_type of 'CV' or 'DV'.")
    sys.exit(1)


def _run_eval() -> None:
    """Invokes JIT-compiled CausalVideoTokenizer on an input video."""

    if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None:
        logging.warning("Aborting. Provide either both encoder and decoder JIT checkpoints, or the full autoencoder JIT model.")
        return

    if args.mode == "torch":
        tokenizer_config = TokenizerConfigs[args.tokenizer_type].value
        tokenizer_config.update(dict(spatial_compression=args.spatial_compression))
        tokenizer_config.update(dict(temporal_compression=args.temporal_compression))
    else:
        tokenizer_config = None

    logging.info(
        f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
    )
    autoencoder = CausalVideoTokenizer(
        checkpoint=args.checkpoint,
        checkpoint_enc=args.checkpoint_enc,
        checkpoint_dec=args.checkpoint_dec,
        tokenizer_config=tokenizer_config,
        device=args.device,
        dtype=args.dtype,
    )

    logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...")
    filepaths = get_filepaths(args.video_pattern)
    logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.")

    for filepath in filepaths:
        logging.info(f"Reading video {filepath} ...")
        video = read_video(filepath)
        video = resize_video(video, short_size=args.short_size)

        logging.info("Invoking the autoencoder model ...")
        batch_video = video[np.newaxis, ...]
        output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0]
        logging.info("Constructing output filepath ...")
        output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
        logging.info(f"Outputting {output_filepath} ...")
        write_video(output_filepath, output_video, fps=args.output_fps)
        if args.save_input:
            ext = os.path.splitext(output_filepath)[-1]
            input_filepath = output_filepath.replace(ext, "_input" + ext)
            write_video(input_filepath, video, fps=args.output_fps)


@logging.catch(reraise=True)
def main() -> None:
    _run_eval()


if __name__ == "__main__":
    main()
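
Beyond the CLI, the video tokenizer can be driven programmatically; a minimal sketch assuming JIT checkpoints on disk (paths are placeholders):

import numpy as np

from cosmos_transfer1.auxiliary.tokenizer.inference.video_lib import CausalVideoTokenizer

tokenizer = CausalVideoTokenizer(
    checkpoint_enc="checkpoints/encoder.jit",  # placeholder path
    checkpoint_dec="checkpoints/decoder.jit",  # placeholder path
)
video = np.random.randint(0, 256, size=(33, 256, 256, 3), dtype=np.uint8)  # TxHxWxC, [0..255]
reconstruction = tokenizer(video[np.newaxis, ...], temporal_window=17)[0]  # sliding-window reconstruction
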
cosmos_transfer1/auxiliary/tokenizer/inference/video_lib.py
ADDED
@@ -0,0 +1,146 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A library for Causal Video Tokenizer inference."""

from typing import Any

import numpy as np
import torch
from tqdm import tqdm

from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
    load_decoder_model,
    load_encoder_model,
    load_model,
    numpy2tensor,
    pad_video_batch,
    tensor2numpy,
    unpad_video_batch,
)


class CausalVideoTokenizer(torch.nn.Module):
    def __init__(
        self,
        checkpoint: str = None,
        checkpoint_enc: str = None,
        checkpoint_dec: str = None,
        tokenizer_config: dict[str, Any] = None,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ) -> None:
        super().__init__()
        self._device = device
        self._dtype = getattr(torch, dtype)
        self._full_model = (
            load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
        )
        self._enc_model = (
            load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
            if checkpoint_enc is not None
            else None
        )
        self._dec_model = (
            load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
            if checkpoint_dec is not None
            else None
        )

    @torch.no_grad()
    def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Reconstructs a batch of video tensors after embedding into a latent.

        Args:
            input_tensor: The input video, Bx3xTxHxW layout, range [-1..1].
        Returns:
            The reconstructed video, layout Bx3xTxHxW, range [-1..1].
        """
        if self._full_model is not None:
            output_tensor = self._full_model(input_tensor)
            output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
        else:
            output_latent = self.encode(input_tensor)[0]
            output_tensor = self.decode(output_latent)
        return output_tensor

    @torch.no_grad()
    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
        """Encodes a video tensor into a CausalVideo latent or code.

        Args:
            input_tensor: The input tensor, Bx3xTxHxW layout, range [-1..1].
        Returns:
            For the causal continuous video (CV) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(t)x(h)x(w), where the compression
                  rate is (T/t x H/h x W/w), with a channel dimension of 16.
            For the causal discrete video (DV) tokenizer, the tuple contains:
                1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which
                   is formed by FSQ levels of (8,8,8,5,5,5).
                2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate
                   is again (T/t x H/h x W/w), with a channel dimension of 6.
        """
        assert input_tensor.ndim == 5, "input video should be of 5D."

        output_latent = self._enc_model(input_tensor)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        return output_latent[:-1]

    @torch.no_grad()
    def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
        """Decodes a CausalVideo latent or code back into a video.

        Args:
            input_latent: The continuous latent Bx16xtxhxw for CV,
                or the discrete indices Bxtxhxw for DV.
        Returns:
            The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16], range [-1..1].
        """
        assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete."
        return self._dec_model(input_latent)

    def forward(
        self,
        video: np.ndarray,
        temporal_window: int = 17,
    ) -> np.ndarray:
        """Reconstructs a video using a pre-trained CausalTokenizer autoencoder.
        Given a video of arbitrary length, forward() invokes the CausalVideoTokenizer
        in a sliding manner with a `temporal_window` size.

        Args:
            video: The input video, BxTxHxWx3 layout, range [0..255].
            temporal_window: The length of the temporal window to process, default=17.
        Returns:
            The reconstructed video in range [0..255], layout BxTxHxWx3.
        """
        assert video.ndim == 5, "input video should be of 5D."
        num_frames = video.shape[1]  # can be of any length.
        output_video_list = []
        for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)):
            # Input video for the current window.
            start, end = idx * temporal_window, (idx + 1) * temporal_window
            input_video = video[:, start:end, ...]

            # Spatio-temporally pad input_video so it's evenly divisible.
            padded_input_video, crop_region = pad_video_batch(input_video)
            input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device)
            output_tensor = self.autoencode(input_tensor)
            padded_output_video = tensor2numpy(output_tensor)
            output_video = unpad_video_batch(padded_output_video, crop_region)

            output_video_list.append(output_video)
        return np.concatenate(output_video_list, axis=1)
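A minimal usage sketch for CausalVideoTokenizer as defined above. The checkpoint path is a placeholder (no checkpoint ships in this commit), and the shapes simply follow the docstrings: a BxTxHxWx3 uint8 video in, a reconstruction with the same layout out.

import numpy as np

from cosmos_transfer1.auxiliary.tokenizer.inference.video_lib import CausalVideoTokenizer

# Placeholder path -- point this at a downloaded autoencoder checkpoint.
tokenizer = CausalVideoTokenizer(checkpoint="checkpoints/autoencoder.jit")

video = np.random.randint(0, 256, size=(1, 33, 256, 256, 3), dtype=np.uint8)
reconstruction = tokenizer(video, temporal_window=17)  # sliding-window forward()
assert reconstruction.shape == video.shape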
cosmos_transfer1/auxiliary/tokenizer/modules/__init__.py
ADDED
@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum

from cosmos_transfer1.auxiliary.tokenizer.modules.distributions import GaussianDistribution, IdentityDistribution
from cosmos_transfer1.auxiliary.tokenizer.modules.layers2d import Decoder, Encoder
from cosmos_transfer1.auxiliary.tokenizer.modules.layers3d import (
    DecoderBase,
    DecoderFactorized,
    EncoderBase,
    EncoderFactorized,
)
from cosmos_transfer1.auxiliary.tokenizer.modules.quantizers import (
    FSQuantizer,
    LFQuantizer,
    ResidualFSQuantizer,
    VectorQuantizer,
)


class EncoderType(Enum):
    Default = Encoder


class DecoderType(Enum):
    Default = Decoder


class Encoder3DType(Enum):
    BASE = EncoderBase
    FACTORIZED = EncoderFactorized


class Decoder3DType(Enum):
    BASE = DecoderBase
    FACTORIZED = DecoderFactorized


class ContinuousFormulation(Enum):
    VAE = GaussianDistribution
    AE = IdentityDistribution


class DiscreteQuantizer(Enum):
    VQ = VectorQuantizer
    LFQ = LFQuantizer
    FSQ = FSQuantizer
    RESFSQ = ResidualFSQuantizer
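These enums exist so a config string can be turned into the concrete module class; `.value` on a member yields the class itself. A small lookup sketch (the member names are real, the usage around them is illustrative):

from cosmos_transfer1.auxiliary.tokenizer.modules import Decoder3DType, DiscreteQuantizer

decoder_name, quantizer_name = "FACTORIZED", "FSQ"       # e.g. read from a config
decoder_cls = Decoder3DType[decoder_name].value          # -> DecoderFactorized
quantizer_cls = DiscreteQuantizer[quantizer_name].value  # -> FSQuantizer
print(decoder_cls.__name__, quantizer_cls.__name__)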
cosmos_transfer1/auxiliary/tokenizer/modules/distributions.py
ADDED
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The distribution modes to use for continuous image tokenizers."""

import torch


class IdentityDistribution(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, parameters):
        return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))


class GaussianDistribution(torch.nn.Module):
    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        super().__init__()
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

    def sample(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)

    def forward(self, parameters):
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
        return self.sample(mean, logvar), (mean, logvar)
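Both formulations expose the same (sample, stats) interface, so the tokenizer can swap VAE and AE behavior without changing call sites. A quick sketch with invented tensor sizes:

import torch

from cosmos_transfer1.auxiliary.tokenizer.modules.distributions import (
    GaussianDistribution,
    IdentityDistribution,
)

# An encoder head for a 4-channel latent emits 2*4 channels: mean and logvar.
params = torch.randn(1, 8, 4, 4)
z, (mean, logvar) = GaussianDistribution()(params)  # reparameterized sample
assert z.shape == mean.shape == logvar.shape == (1, 4, 4, 4)

# The AE formulation passes parameters through and returns dummy stats.
z_ae, _ = IdentityDistribution()(params)
assert torch.equal(z_ae, params)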
cosmos_transfer1/auxiliary/tokenizer/modules/layers2d.py
ADDED
@@ -0,0 +1,329 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The model definition for Continuous 2D layers

Adapted from: https://github.com/CompVis/stable-diffusion/blob/
21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py

[Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors]
https://github.com/CompVis/stable-diffusion/blob/
21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/LICENSE
"""

import math

import numpy as np

# pytorch_diffusion + derived encoder decoder
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger as logging

from cosmos_transfer1.auxiliary.tokenizer.modules.patching import Patcher, UnPatcher
from cosmos_transfer1.auxiliary.tokenizer.modules.utils import Normalize, nonlinearity


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)
        return self.conv(x)


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad = (0, 1, 0, 1)
        x = F.pad(x, pad, mode="constant", value=0)
        return self.conv(x)


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int = None,
        dropout: float,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nin_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()

        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO (freda): Consider reusing the Attn implementations in `imaginaire`,
        # since that one is going to be based on TransformerEngine's attn op,
        # which could ease CP implementations.
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int,
        **ignore_kwargs,
    ):
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # Patcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        in_channels = in_channels * patch_size * patch_size

        # calculate the number of downsample operations
        self.num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
        assert (
            self.num_downsamples <= self.num_resolutions
        ), f"we can only downsample {self.num_resolutions} times at most"

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1)

        curr_res = resolution // patch_size
        in_ch_mult = (1,) + tuple(channels_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = channels * in_ch_mult[i_level]
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level < self.num_downsamples:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patcher(x)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level < self.num_downsamples:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        out_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int,
        **ignore_kwargs,
    ):
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # UnPatcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        out_ch = out_channels * patch_size * patch_size

        # calculate the number of upsample operations
        self.num_upsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
        assert self.num_upsamples <= self.num_resolutions, f"we can only upsample {self.num_resolutions} times at most"

        block_in = channels * channels_mult[self.num_resolutions - 1]
        curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level >= (self.num_resolutions - self.num_upsamples):
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level >= (self.num_resolutions - self.num_upsamples):
                h = self.up[i_level].upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        h = self.unpatcher(h)
        return h
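A shape smoke test for the 2D pair above. The config here is invented for illustration (the shipped configs live in tokenizer/networks/configs.py); with spatial_compression=4 and patch_size defaulting to 1, the encoder applies two Downsample stages and the decoder mirrors them.

import torch

from cosmos_transfer1.auxiliary.tokenizer.modules.layers2d import Decoder, Encoder

cfg = dict(
    channels=32,
    channels_mult=[1, 2, 4],
    num_res_blocks=1,
    attn_resolutions=[16],
    dropout=0.0,
    resolution=64,
    z_channels=8,
    spatial_compression=4,  # log2(4) = 2 down/up stages
)
enc = Encoder(in_channels=3, **cfg)
dec = Decoder(out_channels=3, **cfg)

x = torch.randn(1, 3, 64, 64)
z = enc(x)
assert z.shape == (1, 8, 16, 16)  # 64 / spatial_compression
assert dec(z).shape == x.shape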
cosmos_transfer1/auxiliary/tokenizer/modules/layers3d.py
ADDED
@@ -0,0 +1,969 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The model definition for 3D layers

Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/
9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889

[MIT License Copyright (c) 2023 Phil Wang]
https://github.com/lucidrains/magvit2-pytorch/blob/
9f49074179c912736e617d61b32be367eb5f993a/LICENSE
"""
import math
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger as logging

from cosmos_transfer1.auxiliary.tokenizer.modules.patching import Patcher, Patcher3D, UnPatcher, UnPatcher3D
from cosmos_transfer1.auxiliary.tokenizer.modules.utils import (
    CausalNormalize,
    batch2space,
    batch2time,
    cast_tuple,
    is_odd,
    nonlinearity,
    replication_pad,
    space2batch,
    time2batch,
)

_LEGACY_NUM_GROUPS = 32


class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in: int = 1,
        chan_out: int = 1,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        pad_mode: str = "constant",
        **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop("dilation", 1)
        stride = kwargs.pop("stride", 1)
        time_stride = kwargs.pop("time_stride", 1)
        time_dilation = kwargs.pop("time_dilation", 1)
        padding = kwargs.pop("padding", 1)

        self.pad_mode = pad_mode
        time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride)
        self.time_pad = time_pad

        self.spatial_pad = (padding, padding, padding, padding)

        stride = (time_stride, stride, stride)
        dilation = (time_dilation, dilation, dilation)
        self.conv3d = nn.Conv3d(
            chan_in,
            chan_out,
            kernel_size,
            stride=stride,
            dilation=dilation,
            **kwargs,
        )

    def _replication_pad(self, x: torch.Tensor) -> torch.Tensor:
        x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1)
        x = torch.cat([x_prev, x], dim=2)
        padding = self.spatial_pad + (0, 0)
        return F.pad(x, padding, mode=self.pad_mode, value=0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._replication_pad(x)
        return self.conv3d(x)


class CausalUpsample3d(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.conv = CausalConv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
        time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
        if isinstance(time_factor, torch.Tensor):
            time_factor = time_factor.item()
        x = x.repeat_interleave(int(time_factor), dim=2)
        # TODO(freda): Check if this causes temporal inconsistency.
        # Should reverse the order of the following two ops
        # for better perf and better temporal smoothness.
        x = self.conv(x)
        return x[..., int(time_factor - 1) :, :, :]


class CausalDownsample3d(nn.Module):
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.conv = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            time_stride=2,
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad = (0, 1, 0, 1, 0, 0)
        x = F.pad(x, pad, mode="constant", value=0)
        x = replication_pad(x)
        x = self.conv(x)
        return x


class CausalHybridUpsample3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        spatial_up: bool = True,
        temporal_up: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        self.conv1 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(3, 1, 1),
            stride=1,
            time_stride=1,
            padding=0,
        )
        self.conv2 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(1, 3, 3),
            stride=1,
            time_stride=1,
            padding=1,
        )
        self.conv3 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            time_stride=1,
            padding=0,
        )
        self.spatial_up = spatial_up
        self.temporal_up = temporal_up

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.spatial_up and not self.temporal_up:
            return x

        # hybrid upsample temporally.
        if self.temporal_up:
            time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
            if isinstance(time_factor, torch.Tensor):
                time_factor = time_factor.item()
            x = x.repeat_interleave(int(time_factor), dim=2)
            x = x[..., int(time_factor - 1) :, :, :]
            x = self.conv1(x) + x

        # hybrid upsample spatially.
        if self.spatial_up:
            x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
            x = self.conv2(x) + x

        # final 1x1x1 conv.
        x = self.conv3(x)
        return x


class CausalHybridDownsample3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        spatial_down: bool = True,
        temporal_down: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        self.conv1 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(1, 3, 3),
            stride=2,
            time_stride=1,
            padding=0,
        )
        self.conv2 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(3, 1, 1),
            stride=1,
            time_stride=2,
            padding=0,
        )
        self.conv3 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            time_stride=1,
            padding=0,
        )
        self.spatial_down = spatial_down
        self.temporal_down = temporal_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.spatial_down and not self.temporal_down:
            return x

        # hybrid downsample spatially.
        if self.spatial_down:
            pad = (0, 1, 0, 1, 0, 0)
            x = F.pad(x, pad, mode="constant", value=0)
            x1 = self.conv1(x)
            x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            x = x1 + x2

        # hybrid downsample temporally.
        if self.temporal_down:
            x = replication_pad(x)
            x1 = self.conv2(x)
            x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1))
            x = x1 + x2

        # final 1x1x1 conv.
        x = self.conv3(x)
        return x


class CausalResnetBlock3d(nn.Module):
    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int = None,
        dropout: float,
        num_groups: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels

        self.norm1 = CausalNormalize(in_channels, num_groups=num_groups)
        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nin_shortcut = (
            CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        x = self.nin_shortcut(x)

        return x + h


class CausalResnetBlockFactorized3d(nn.Module):
    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int = None,
        dropout: float,
        num_groups: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels

        self.norm1 = CausalNormalize(in_channels, num_groups=1)
        self.conv1 = nn.Sequential(
            CausalConv3d(
                in_channels,
                out_channels,
                kernel_size=(1, 3, 3),
                stride=1,
                padding=1,
            ),
            CausalConv3d(
                out_channels,
                out_channels,
                kernel_size=(3, 1, 1),
                stride=1,
                padding=0,
            ),
        )
        self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = nn.Sequential(
            CausalConv3d(
                out_channels,
                out_channels,
                kernel_size=(1, 3, 3),
                stride=1,
                padding=1,
            ),
            CausalConv3d(
                out_channels,
                out_channels,
                kernel_size=(3, 1, 1),
                stride=1,
                padding=0,
            ),
        )
        self.nin_shortcut = (
            CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        x = self.nin_shortcut(x)

        return x + h


class CausalAttnBlock(nn.Module):
    def __init__(self, in_channels: int, num_groups: int) -> None:
        super().__init__()

        self.norm = CausalNormalize(in_channels, num_groups=num_groups)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        q, batch_size = time2batch(q)
        k, batch_size = time2batch(k)
        v, batch_size = time2batch(v)

        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = batch2time(h_, batch_size)
        h_ = self.proj_out(h_)
        return x + h_


class CausalTemporalAttnBlock(nn.Module):
    def __init__(self, in_channels: int, num_groups: int) -> None:
        super().__init__()

        self.norm = CausalNormalize(in_channels, num_groups=num_groups)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        q, batch_size, height = space2batch(q)
        k, _, _ = space2batch(k)
        v, _, _ = space2batch(v)

        bhw, c, t = q.shape
        q = q.permute(0, 2, 1)  # (bhw, t, c)
        k = k.permute(0, 2, 1)  # (bhw, t, c)
        v = v.permute(0, 2, 1)  # (bhw, t, c)

        w_ = torch.bmm(q, k.permute(0, 2, 1))  # (bhw, t, t)
        w_ = w_ * (int(c) ** (-0.5))

        # Apply causal mask
        mask = torch.tril(torch.ones_like(w_))
        w_ = w_.masked_fill(mask == 0, float("-inf"))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        h_ = torch.bmm(w_, v)  # (bhw, t, c)
        h_ = h_.permute(0, 2, 1).reshape(bhw, c, t)  # (bhw, c, t)

        h_ = batch2space(h_, batch_size, height)
        h_ = self.proj_out(h_)
        return x + h_


class EncoderBase(nn.Module):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        **ignore_kwargs,
    ) -> None:
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # Patcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        in_channels = in_channels * patch_size * patch_size

        # downsampling
        self.conv_in = CausalConv3d(in_channels, channels, kernel_size=3, stride=1, padding=1)

        # num of groups for GroupNorm, num_groups=1 for LayerNorm.
        num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS)
        curr_res = resolution // patch_size
        in_ch_mult = (1,) + tuple(channels_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = channels * in_ch_mult[i_level]
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    CausalResnetBlock3d(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        num_groups=num_groups,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(CausalAttnBlock(block_in, num_groups=num_groups))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = CausalDownsample3d(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = CausalResnetBlock3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=num_groups,
        )
        self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups)
        self.mid.block_2 = CausalResnetBlock3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=num_groups,
        )

        # end
        self.norm_out = CausalNormalize(block_in, num_groups=num_groups)
        self.conv_out = CausalConv3d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def patcher3d(self, x: torch.Tensor) -> torch.Tensor:
        x, batch_size = time2batch(x)
        x = self.patcher(x)
        x = batch2time(x, batch_size)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patcher3d(x)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
            else:
                # temporal downsample (last level)
                time_factor = 1 + 1 * (hs[-1].shape[2] > 1)
                if isinstance(time_factor, torch.Tensor):
                    time_factor = time_factor.item()
                hs[-1] = replication_pad(hs[-1])
                hs.append(
                    F.avg_pool3d(
                        hs[-1],
                        kernel_size=[time_factor, 1, 1],
                        stride=[2, 1, 1],
                    )
                )

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class DecoderBase(nn.Module):
    def __init__(
        self,
        out_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        **ignore_kwargs,
    ):
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # UnPatcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        out_ch = out_channels * patch_size * patch_size

        block_in = channels * channels_mult[self.num_resolutions - 1]
        curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # num of groups for GroupNorm, num_groups=1 for LayerNorm.
        num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = CausalResnetBlock3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=num_groups,
        )
        self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups)
        self.mid.block_2 = CausalResnetBlock3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=num_groups,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    CausalResnetBlock3d(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        num_groups=num_groups,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(CausalAttnBlock(block_in, num_groups=num_groups))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = CausalUpsample3d(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = CausalNormalize(block_in, num_groups=num_groups)
        self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor:
        x, batch_size = time2batch(x)
        x = self.unpatcher(x)
        x = batch2time(x, batch_size)

        return x

    def forward(self, z):
        h = self.conv_in(z)

        # middle block.
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # decoder blocks.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
            else:
                # temporal upsample (last level)
                time_factor = 1.0 + 1.0 * (h.shape[2] > 1)
                if isinstance(time_factor, torch.Tensor):
                    time_factor = time_factor.item()
                h = h.repeat_interleave(int(time_factor), dim=2)
                h = h[..., int(time_factor - 1) :, :, :]

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        h = self.unpatcher3d(h)
        return h


class EncoderFactorized(nn.Module):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int = 16,
        temporal_compression: int = 8,
        **ignore_kwargs,
    ) -> None:
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # Patcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        in_channels = in_channels * patch_size * patch_size * patch_size

        # calculate the number of downsample operations
        self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
        assert (
            self.num_spatial_downs <= self.num_resolutions
        ), f"Spatially downsample {self.num_resolutions} times at most"

        self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
        assert (
            self.num_temporal_downs <= self.num_resolutions
        ), f"Temporally downsample {self.num_resolutions} times at most"

        # downsampling
        self.conv_in = nn.Sequential(
            CausalConv3d(
                in_channels,
                channels,
                kernel_size=(1, 3, 3),
                stride=1,
                padding=1,
            ),
            CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0),
        )

        curr_res = resolution // patch_size
        in_ch_mult = (1,) + tuple(channels_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = channels * in_ch_mult[i_level]
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    CausalResnetBlockFactorized3d(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        num_groups=1,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(
                        nn.Sequential(
                            CausalAttnBlock(block_in, num_groups=1),
                            CausalTemporalAttnBlock(block_in, num_groups=1),
                        )
                    )
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                spatial_down = i_level < self.num_spatial_downs
                temporal_down = i_level < self.num_temporal_downs
                down.downsample = CausalHybridDownsample3d(
                    block_in,
                    spatial_down=spatial_down,
                    temporal_down=temporal_down,
                )
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = CausalResnetBlockFactorized3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=1,
        )
        self.mid.attn_1 = nn.Sequential(
            CausalAttnBlock(block_in, num_groups=1),
            CausalTemporalAttnBlock(block_in, num_groups=1),
        )
        self.mid.block_2 = CausalResnetBlockFactorized3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=1,
        )

        # end
        self.norm_out = CausalNormalize(block_in, num_groups=1)
        self.conv_out = nn.Sequential(
            CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
            CausalConv3d(
                z_channels,
                z_channels,
                kernel_size=(3, 1, 1),
                stride=1,
                padding=0,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patcher3d(x)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class DecoderFactorized(nn.Module):
    def __init__(
        self,
        out_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int = 16,
        temporal_compression: int = 8,
        **ignore_kwargs,
    ):
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # UnPatcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        out_ch = out_channels * patch_size * patch_size * patch_size

        # calculate the number of upsample operations
        self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
        assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most"
        self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
        assert (
            self.num_temporal_ups <= self.num_resolutions
        ), f"Temporally upsample {self.num_resolutions} times at most"

        block_in = channels * channels_mult[self.num_resolutions - 1]
        curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = nn.Sequential(
            CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1),
            CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0),
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = CausalResnetBlockFactorized3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=1,
        )
        self.mid.attn_1 = nn.Sequential(
            CausalAttnBlock(block_in, num_groups=1),
            CausalTemporalAttnBlock(block_in, num_groups=1),
        )
        self.mid.block_2 = CausalResnetBlockFactorized3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=1,
        )

        legacy_mode = ignore_kwargs.get("legacy_mode", False)
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    CausalResnetBlockFactorized3d(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        num_groups=1,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(
                        nn.Sequential(
                            CausalAttnBlock(block_in, num_groups=1),
                            CausalTemporalAttnBlock(block_in, num_groups=1),
                        )
                    )
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # The layer index for temporal/spatial downsampling performed
                # in the encoder should correspond to the layer index, in
                # reverse order, at which upsampling is performed in the decoder.
                # If you have a pre-trained model, you can simply finetune.
                i_level_reverse = self.num_resolutions - i_level - 1
                if legacy_mode:
                    temporal_up = i_level_reverse < self.num_temporal_ups
                else:
                    temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1
                spatial_up = temporal_up or (
                    i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups
                )
                up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = CausalNormalize(block_in, num_groups=1)
        self.conv_out = nn.Sequential(
            CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1),
|
945 |
+
CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0),
|
946 |
+
)
|
947 |
+
|
948 |
+
def forward(self, z):
|
949 |
+
h = self.conv_in(z)
|
950 |
+
|
951 |
+
# middle block.
|
952 |
+
h = self.mid.block_1(h)
|
953 |
+
h = self.mid.attn_1(h)
|
954 |
+
h = self.mid.block_2(h)
|
955 |
+
|
956 |
+
# decoder blocks.
|
957 |
+
for i_level in reversed(range(self.num_resolutions)):
|
958 |
+
for i_block in range(self.num_res_blocks + 1):
|
959 |
+
h = self.up[i_level].block[i_block](h)
|
960 |
+
if len(self.up[i_level].attn) > 0:
|
961 |
+
h = self.up[i_level].attn[i_block](h)
|
962 |
+
if i_level != 0:
|
963 |
+
h = self.up[i_level].upsample(h)
|
964 |
+
|
965 |
+
h = self.norm_out(h)
|
966 |
+
h = nonlinearity(h)
|
967 |
+
h = self.conv_out(h)
|
968 |
+
h = self.unpatcher3d(h)
|
969 |
+
return h
|
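
To make the upsampling bookkeeping above concrete, here is a minimal standalone sketch (not part of the commit) that reproduces the `num_spatial_ups` / `num_temporal_ups` arithmetic and the per-level `spatial_up` / `temporal_up` flags in the non-legacy mode; the assumed numbers (three resolution levels, `spatial_compression=8`, `temporal_compression=8`, `patch_size=4`) match the `continuous_video` config added later in this commit.

import math

# Illustrative sketch only (assumptions: 3 resolution levels, non-legacy mode).
spatial_compression, temporal_compression, patch_size = 8, 8, 4
num_resolutions = 3  # len(channels_mult) for channels_mult=[2, 4, 4]

num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size))

for i_level in reversed(range(num_resolutions)):
    if i_level == 0:
        continue  # the last decoder level has no upsampler
    i_level_reverse = num_resolutions - i_level - 1
    temporal_up = 0 < i_level_reverse < num_temporal_ups + 1
    spatial_up = temporal_up or (i_level_reverse < num_spatial_ups and num_spatial_ups > num_temporal_ups)
    print(f"level {i_level}: spatial_up={spatial_up}, temporal_up={temporal_up}")

With these defaults only one hybrid upsample stage is active; the remaining 4x spatial/temporal expansion is handled by the UnPatcher3D.
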
cosmos_transfer1/auxiliary/tokenizer/modules/patching.py
ADDED
@@ -0,0 +1,311 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The patcher and unpatcher implementation for 2D and 3D data.

The idea of the Haar wavelet is to compute the LL, LH, HL, HH components as two 1D convolutions:
one on the rows and one on the columns.
For example, given a 1D signal [a, b], the low-frequency component is [a + b] / 2 and the high-frequency one is [a - b] / 2.
We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
For the H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
In principle, the additional Haar wavelet is typically applied only over the LL component, but here we apply it to all
components, as we need to support downsampling by more than 2x.
For example, 4x downsampling can be done by a 2x Haar followed by an additional 2x Haar, with shapes:
[3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
"""

import torch
import torch.nn.functional as F
from einops import rearrange

_WAVELETS = {
    "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
    "rearrange": torch.tensor([1.0, 1.0]),
}
_PERSISTENT = False


class Patcher(torch.nn.Module):
    """A module to convert image tensors into patches using torch operations.

    The main difference from `class Patching` is that this module implements
    all operations using torch, rather than python or numpy, for efficiency purposes.

    It's bit-wise identical to the Patching module outputs, with the added
    benefit of being torch.jit scriptable.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT)
        self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
        self.register_buffer(
            "_arange",
            torch.arange(_WAVELETS[patch_method].shape[0]),
            persistent=_PERSISTENT,
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        if self.patch_method == "haar":
            return self._haar(x)
        elif self.patch_method == "rearrange":
            return self._arrange(x)
        else:
            raise ValueError("Unknown patch method: " + self.patch_method)

    def _dwt(self, x, mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets

        n = h.shape[0]
        g = x.shape[1]
        hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(dtype=dtype)
        hl = hl.to(dtype=dtype)

        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
        xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
        xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
        xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
        xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))

        out = torch.cat([xll, xlh, xhl, xhh], dim=1)
        if rescale:
            out = out / 2
        return out

    def _haar(self, x):
        for _ in self.range:
            x = self._dwt(x, rescale=True)
        return x

    def _arrange(self, x):
        x = rearrange(
            x,
            "b c (h p1) (w p2) -> b (c p1 p2) h w",
            p1=self.patch_size,
            p2=self.patch_size,
        ).contiguous()
        return x


class Patcher3D(Patcher):
    """A 3D discrete wavelet transform for video data; expects a 5D tensor, i.e. a batch of videos."""

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)
        self.register_buffer(
            "patch_size_buffer",
            patch_size * torch.ones([1], dtype=torch.int32),
            persistent=_PERSISTENT,
        )

    def _dwt(self, x, wavelet, mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets

        n = h.shape[0]
        g = x.shape[1]
        hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(dtype=dtype)
        hl = hl.to(dtype=dtype)

        # Handles temporal axis.
        x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
        xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
        xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))

        # Handles spatial axes.
        xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))

        xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))

        out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
        if rescale:
            out = out / (2 * torch.sqrt(torch.tensor(2.0)))
        return out

    def _haar(self, x):
        xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
        for _ in self.range:
            x = self._dwt(x, "haar", rescale=True)
        return x

    def _arrange(self, x):
        xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
        x = rearrange(
            x,
            "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
            p1=self.patch_size,
            p2=self.patch_size,
            p3=self.patch_size,
        ).contiguous()
        return x


class UnPatcher(torch.nn.Module):
    """A module to convert patches into image tensors using torch operations.

    The main difference from `class Unpatching` is that this module implements
    all operations using torch, rather than python or numpy, for efficiency purposes.

    It's bit-wise identical to the Unpatching module outputs, with the added
    benefit of being torch.jit scriptable.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT)
        self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
        self.register_buffer(
            "_arange",
            torch.arange(_WAVELETS[patch_method].shape[0]),
            persistent=_PERSISTENT,
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        if self.patch_method == "haar":
            return self._ihaar(x)
        elif self.patch_method == "rearrange":
            return self._iarrange(x)
        else:
            raise ValueError("Unknown patch method: " + self.patch_method)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets
        n = h.shape[0]

        g = x.shape[1] // 4
        hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(dtype=dtype)
        hl = hl.to(dtype=dtype)

        xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)

        # Inverse transform.
        yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
        yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
        yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
        yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
        y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))
        y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))

        if rescale:
            y = y * 2
        return y

    def _ihaar(self, x):
        for _ in self.range:
            x = self._idwt(x, "haar", rescale=True)
        return x

    def _iarrange(self, x):
        x = rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
        )
        return x


class UnPatcher3D(UnPatcher):
    """A 3D inverse discrete wavelet transform for video wavelet decompositions."""

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        dtype = x.dtype
        h = self.wavelets
        n = h.shape[0]

        g = x.shape[1] // 8  # split into 8 spatio-temporal filtered tensors.
        hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hl = hl.to(dtype=dtype)
        hh = hh.to(dtype=dtype)

        xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)

        # Handles width transposed convolutions.
        xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))

        xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))

        xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))

        xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
        xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))

        # Handles height transposed convolutions.
        xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
        xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))

        # Handles time axis transposed convolutions.
        x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
        x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))

        if rescale:
            x = x * (2 * torch.sqrt(torch.tensor(2.0)))
        return x

    def _ihaar(self, x):
        for _ in self.range:
            x = self._idwt(x, "haar", rescale=True)
        x = x[:, :, self.patch_size - 1 :, ...]
        return x

    def _iarrange(self, x):
        x = rearrange(
            x,
            "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
            p1=self.patch_size,
            p2=self.patch_size,
            p3=self.patch_size,
        )
        x = x[:, :, self.patch_size - 1 :, ...]
        return x
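
A minimal usage sketch (not part of the commit, assuming the module is importable from the path above) that exercises the shape progression described in the docstring, [3, 256, 256] -> [48, 64, 64] for patch_size=4, and checks the Haar round trip:

import torch

from cosmos_transfer1.auxiliary.tokenizer.modules.patching import Patcher, UnPatcher

x = torch.randn(1, 3, 256, 256)
patcher = Patcher(patch_size=4, patch_method="haar")
unpatcher = UnPatcher(patch_size=4, patch_method="haar")

tokens = patcher(x)
print(tokens.shape)  # torch.Size([1, 48, 64, 64]): two 2x Haar levels, 16x the channels

x_rec = unpatcher(tokens)
print(x_rec.shape, (x_rec - x).abs().max())  # round-trip error should be at float precision
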
cosmos_transfer1/auxiliary/tokenizer/modules/quantizers.py
ADDED
@@ -0,0 +1,513 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quantizers for discrete image and video tokenization."""

from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce
from loguru import logger as logging

from cosmos_transfer1.auxiliary.tokenizer.modules.utils import (
    default,
    entropy,
    pack_one,
    rearrange,
    round_ste,
    unpack_one,
)


class ResidualFSQuantizer(nn.Module):
    """Residual Finite Scalar Quantization

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, levels: list[int], num_quantizers: int, **ignore_kwargs):
        super().__init__()
        self.dtype = ignore_kwargs.get("dtype", torch.float32)
        self.layers = nn.ModuleList([FSQuantizer(levels=levels) for _ in range(num_quantizers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        indices_stack = []
        residual = x
        quantized_out = 0
        loss_out = 0
        for i, layer in enumerate(self.layers):
            quant_indices, z, loss = layer(residual)
            indices_stack.append(quant_indices)
            residual = residual - z.detach()
            quantized_out = quantized_out + z
            loss_out = loss_out + loss
        self.residual = residual
        indices = torch.stack(indices_stack, dim=1)
        return indices, quantized_out.to(self.dtype), loss_out.to(self.dtype)

    def indices_to_codes(self, indices_stack: torch.Tensor) -> torch.Tensor:
        quantized_out = 0
        for layer, indices in zip(self.layers, indices_stack.transpose(0, 1)):
            quantized_out += layer.indices_to_codes(indices)
        return quantized_out


class FSQuantizer(nn.Module):
    """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505

    Code adapted from Jax version in Appendix A.1.

    Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
    vector_quantize_pytorch/finite_scalar_quantization.py
    [Copyright (c) 2020 Phil Wang]
    https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE
    """

    def __init__(
        self,
        levels: list[int],
        dim: Optional[int] = None,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        **ignore_kwargs,
    ):
        super().__init__()
        self.dtype = ignore_kwargs.get("dtype", torch.bfloat16)
        _levels = torch.tensor(levels, dtype=torch.int32)
        self.register_buffer("_levels", _levels, persistent=False)

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
        self.register_buffer("_basis", _basis, persistent=False)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat).float()
        return (zhat * self._basis).sum(dim=-1).to(torch.int32)

    def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor:
        """Inverse of `codes_to_indices`."""
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        indices = rearrange(indices, "... -> ... 1")
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, "... c d -> ... (c d)")

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, "b ... d -> b d ...")

        return codes.to(self.dtype)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """
        is_img_or_video = z.ndim >= 4

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            z = rearrange(z, "b d ... -> b ... d")
            z, ps = pack_one(z, "b * d")

        assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"

        z = self.project_in(z)

        z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)

        codes = self.quantize(z)
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, "b n c d -> b n (c d)")

        out = self.project_out(codes)

        # reconstitute image or video dimensions

        if is_img_or_video:
            out = unpack_one(out, ps, "b * d")
            out = rearrange(out, "b ... d -> b d ...")
            indices = unpack_one(indices, ps, "b * c")
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True))
        else:
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1)

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, "... 1 -> ...")

        return (indices, out.to(self.dtype), dummy_loss)


class VectorQuantizer(nn.Module):
    """Improved version over VectorQuantizer. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.

    Adapted from: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/
    taming/modules/vqvae/quantize.py

    [Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer]
    https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/License.txt
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        beta: float = 0.25,
        remap: str = None,
        unknown_index: str = "random",
        sane_index_shape: bool = False,
        legacy: bool = True,
        use_norm=False,
        **ignore_kwargs,
    ):
        super().__init__()
        self.n_e = num_embeddings
        self.e_dim = embedding_dim
        self.beta = beta
        self.legacy = legacy
        self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = num_embeddings

        self.sane_index_shape = sane_index_shape
        self.dtype = ignore_kwargs.get("dtype", torch.float32)

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        z = rearrange(z, "b c h w -> b h w c").contiguous()
        z_flattened = z.view(-1, self.e_dim)

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2
            * torch.einsum(
                "bd,dn->bn",
                z_flattened,
                rearrange(self.embedding.weight, "n d -> d n"),
            )
        )

        encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.n_e, device=z.device)
        encodings.scatter_(1, encoding_indices, 1)
        z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
        min_encodings = None

        z_q, z = self.norm(z_q), self.norm(z)

        # compute loss for embedding
        commit_loss = torch.mean((z_q - z.detach()) ** 2, dim=[1, 2, 3], keepdim=True)
        emb_loss = torch.mean((z_q.detach() - z) ** 2, dim=[1, 2, 3], keepdim=True)
        if not self.legacy:
            loss = self.beta * emb_loss + commit_loss
        else:
            loss = emb_loss + self.beta * commit_loss

        # preserve gradients
        z_q = z + (z_q - z).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # reshape back to match original input shape
        z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()

        if self.remap is not None:
            min_encoding_indices = encoding_indices.squeeze(1).reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(encoding_indices.squeeze(1))
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        # TODO: return (indices, z_q, loss)
        return (
            z_q,
            loss,
            (
                encoding_indices.squeeze(1),
                min_encodings,
                commit_loss.mean().detach(),
                self.beta * emb_loss.mean().detach(),
                perplexity.mean().detach(),
            ),
        )

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class LFQuantizer(nn.Module):
    """Lookup-Free Quantization

    Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
    vector_quantize_pytorch/lookup_free_quantization.py
    [Copyright (c) 2020 Phil Wang]
    https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE
    """

    def __init__(
        self,
        *,
        codebook_size: int,
        codebook_dim: int,
        embed_dim: Optional[int] = None,  # if None, use codebook_dim
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        default_temp: float = 0.01,
        entropy_loss: bool = False,
        **ignore_kwargs,
    ):
        """Lookup-Free Quantization

        Args:
            codebook_size (int): The number of entries in the codebook.
            codebook_dim (int): The number of bits in each code.
            embed_dim (Optional[int], optional): The dimension of the input embedding. Defaults to None.
            entropy_loss_weight (float, optional): Weight for the entropy loss. Defaults to 0.1.
            commitment_loss_weight (float, optional): Weight for commitment loss. Defaults to 0.25.
            default_temp (float, optional): The temperature to use. Defaults to 0.01.
            entropy_loss (bool, optional): Flag for entropy loss. Defaults to False.
        """
        super().__init__()
        self.entropy_loss = entropy_loss
        self.codebook_dim = codebook_dim
        self.default_temp = default_temp
        self.entrop_loss_weight = entropy_loss_weight
        self.commitment_loss_weight = commitment_loss_weight
        embed_dim = embed_dim or codebook_dim

        has_projections = embed_dim != codebook_dim
        self.project_in = nn.Linear(embed_dim, codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, embed_dim) if has_projections else nn.Identity()
        logging.info(f"LFQ: has_projections={has_projections}, dim_in={embed_dim}, codebook_dim={codebook_dim}")

        self.dtype = ignore_kwargs.get("dtype", torch.float32)

        if entropy_loss:
            assert 2**codebook_dim == codebook_size, "codebook size must be 2 ** codebook_dim"
            self.codebook_size = codebook_size

            self.register_buffer(
                "mask",
                2 ** torch.arange(codebook_dim - 1, -1, -1),
                persistent=False,
            )
            self.register_buffer("zero", torch.tensor(0.0), persistent=False)

            all_codes = torch.arange(codebook_size)
            bits = ((all_codes[..., None].int() & self.mask) != 0).float()
            codebook = 2 * bits - 1.0

            self.register_buffer("codebook", codebook, persistent=False)  # [codebook_size, codebook_dim]

    def forward(self, z: torch.Tensor, temp: float = None) -> torch.Tensor:
        temp = temp or self.default_temp

        z = rearrange(z, "b d ... -> b ... d")
        z, ps = pack_one(z, "b * d")
        z = self.project_in(z)

        # split out number of codebooks
        z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)

        # quantization
        original_input = z

        codebook_value = torch.ones_like(z)
        z_q = torch.where(z > 0, codebook_value, -codebook_value)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # commit loss
        commit_loss = ((original_input - z_q.detach()) ** 2).mean(dim=[1, 2, 3])

        z_q = rearrange(z_q, "b n c d -> b n (c d)")
        z_q = self.project_out(z_q)

        # reshape
        z_q = unpack_one(z_q, ps, "b * d")
        z_q = rearrange(z_q, "b ... d -> b d ...")

        loss = self.commitment_loss_weight * commit_loss

        # entropy loss (eq-5)
        if self.entropy_loss:
            # indices
            indices = reduce((z > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
            indices = unpack_one(indices, ps, "b * c")
            indices = rearrange(indices, "... 1 -> ...")

            distance = -2 * torch.einsum(
                "... i d, j d -> ... i j",
                original_input,
                self.codebook.to(original_input.dtype),
            )
            prob = (-distance / temp).softmax(dim=-1)
            per_sample_entropy = entropy(prob).mean(dim=[1, 2])
            avg_prob = reduce(prob, "... c d -> c d", "mean")
            codebook_entropy = entropy(avg_prob).mean()
            entropy_aux_loss = per_sample_entropy - codebook_entropy

            loss += self.entrop_loss_weight * entropy_aux_loss

            # TODO: return (indices, z_q, loss)
            return (
                z_q,
                loss.unsqueeze(1).unsqueeze(1).unsqueeze(1),
                (
                    indices,
                    self.commitment_loss_weight * commit_loss.mean().detach(),
                    self.entrop_loss_weight * entropy_aux_loss.mean().detach(),
                    self.entrop_loss_weight * per_sample_entropy.mean().detach(),
                    self.entrop_loss_weight * codebook_entropy.mean().detach(),
                ),
            )
        else:
            return (
                z_q,
                loss.unsqueeze(1).unsqueeze(1).unsqueeze(1),
                self.commitment_loss_weight * commit_loss.mean().detach(),
            )


class InvQuantizerJit(nn.Module):
    """Use for decoder_jit to trace quantizer in discrete tokenizer"""

    def __init__(self, quantizer):
        super().__init__()
        self.quantizer = quantizer

    def forward(self, indices: torch.Tensor):
        codes = self.quantizer.indices_to_codes(indices)
        return codes.to(self.quantizer.dtype)
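
A minimal usage sketch (not part of the commit; the import path is assumed) of the FSQ path that the discrete tokenizer configs below rely on; with levels=[8, 8, 8, 5, 5, 5] the implicit codebook has 8*8*8*5*5*5 = 64000 entries over a 6-dim latent:

import torch

from cosmos_transfer1.auxiliary.tokenizer.modules.quantizers import FSQuantizer

quantizer = FSQuantizer(levels=[8, 8, 8, 5, 5, 5], dtype=torch.float32)
print(quantizer.codebook_size)  # 64000

z = torch.randn(1, 6, 4, 4)  # (batch, codebook_dim, height, width)
indices, z_q, _ = quantizer(z)
print(indices.shape, z_q.shape)  # (1, 4, 4) int32 codes, (1, 6, 4, 4) quantized latents

# indices_to_codes inverts the code lookup, so decoding from stored indices
# should reproduce the quantized latents (prints True).
print(torch.allclose(quantizer.indices_to_codes(indices), z_q))
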
cosmos_transfer1/auxiliary/tokenizer/modules/utils.py
ADDED
@@ -0,0 +1,116 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared utilities for the networks module."""

from typing import Any

import torch
from einops import pack, rearrange, unpack


def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    batch_size = x.shape[0]
    return rearrange(x, "b c t h w -> (b t) c h w"), batch_size


def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)


def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    batch_size, height = x.shape[0], x.shape[-2]
    return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height


def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
    return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)


def cast_tuple(t: Any, length: int = 1) -> Any:
    return t if isinstance(t, tuple) else ((t,) * length)


def replication_pad(x):
    return torch.cat([x[:, :, :1, ...], x], dim=2)


def divisible_by(num: int, den: int) -> bool:
    return (num % den) == 0


def is_odd(n: int) -> bool:
    return not divisible_by(n, 2)


def nonlinearity(x):
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class CausalNormalize(torch.nn.Module):
    def __init__(self, in_channels, num_groups=1):
        super().__init__()
        self.norm = torch.nn.GroupNorm(
            num_groups=num_groups,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
        )
        self.num_groups = num_groups

    def forward(self, x):
        # if num_groups != 1, we apply a spatio-temporal groupnorm for backward compatibility purposes.
        # All new models should use num_groups=1; otherwise causality is not guaranteed.
        if self.num_groups == 1:
            x, batch_size = time2batch(x)
            return batch2time(self.norm(x), batch_size)
        return self.norm(x)


def exists(v):
    return v is not None


def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()


def log(t, eps=1e-5):
    return t.clamp(min=eps).log()


def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)
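
A small sketch (not part of the commit) of the straight-through estimator implemented by round_ste above: the forward pass rounds, while gradients flow through as if rounding were the identity:

import torch

from cosmos_transfer1.auxiliary.tokenizer.modules.utils import round_ste

z = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
zhat = round_ste(z)
print(zhat)  # tensor([ 0.,  2., -1.], ...) -- rounded values in the forward pass

zhat.sum().backward()
print(z.grad)  # tensor([1., 1., 1.]) -- straight-through: d(zhat)/dz acts like the identity
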
cosmos_transfer1/auxiliary/tokenizer/networks/__init__.py
ADDED
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum

from cosmos_transfer1.auxiliary.tokenizer.networks.configs import continuous_image as continuous_image_dict
from cosmos_transfer1.auxiliary.tokenizer.networks.configs import continuous_video as continuous_video_dict
from cosmos_transfer1.auxiliary.tokenizer.networks.configs import discrete_image as discrete_image_dict
from cosmos_transfer1.auxiliary.tokenizer.networks.configs import discrete_video as discrete_video_dict
from cosmos_transfer1.auxiliary.tokenizer.networks.continuous_image import ContinuousImageTokenizer
from cosmos_transfer1.auxiliary.tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer
from cosmos_transfer1.auxiliary.tokenizer.networks.discrete_image import DiscreteImageTokenizer
from cosmos_transfer1.auxiliary.tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer


class TokenizerConfigs(Enum):
    CI = continuous_image_dict
    DI = discrete_image_dict
    CV = continuous_video_dict
    DV = discrete_video_dict


class TokenizerModels(Enum):
    CI = ContinuousImageTokenizer
    DI = DiscreteImageTokenizer
    CV = CausalContinuousVideoTokenizer
    DV = CausalDiscreteVideoTokenizer
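
The two enums share the short names ("CI", "DI", "CV", "DV"), so a model class and its default config can be looked up together by the same key; a hypothetical usage sketch (not part of the commit):

from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs, TokenizerModels

name = "DV"  # causal discrete video tokenizer
config = TokenizerConfigs[name].value
model_cls = TokenizerModels[name].value
print(model_cls.__name__, config["spatial_compression"], config["temporal_compression"])
# CausalDiscreteVideoTokenizer 16 8
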
cosmos_transfer1/auxiliary/tokenizer/networks/configs.py
ADDED
@@ -0,0 +1,147 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The default image and video tokenizer configs."""

from cosmos_transfer1.auxiliary.tokenizer.modules import (
    ContinuousFormulation,
    Decoder3DType,
    DecoderType,
    DiscreteQuantizer,
    Encoder3DType,
    EncoderType,
)

continuous_image = dict(
    # The attention resolution for res blocks.
    attn_resolutions=[32],
    # The base number of channels.
    channels=128,
    # The channel multiplier for each resolution.
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    # The spatial compression ratio.
    spatial_compression=16,
    # The number of layers in each res block.
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    # The output latent dimension (channels).
    latent_channels=16,
    # The encoder output channels just before sampling.
    # Which is also the decoder's input channels.
    z_channels=16,
    # A factor over the z_channels, to get the total channels the encoder should output.
    # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels.
    z_factor=1,
    name="CI",
    # What formulation to use, either "AE" or "VAE".
    # Chose VAE here, since the pre-trained ckpt were of a VAE formulation.
    formulation=ContinuousFormulation.AE.name,
    # Specify type of encoder ["Default", "LiteVAE"]
    encoder=EncoderType.Default.name,
    # Specify type of decoder ["Default"]
    decoder=DecoderType.Default.name,
)

discrete_image = dict(
    # The attention resolution for res blocks.
    attn_resolutions=[32],
    # The base number of channels.
    channels=128,
    # The channel multiplier for each resolution.
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    # The spatial compression ratio.
    spatial_compression=16,
    # The number of layers in each res block.
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    # The encoder output channels just before sampling.
    z_channels=256,
    # A factor over the z_channels, to get the total channels the encoder should output.
    # for discrete tokenization, often we directly use the vector, so z_factor=1.
    z_factor=1,
    # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ.
    quantizer=DiscreteQuantizer.FSQ.name,
    # The embedding dimension post-quantization, which is also the input channels of the decoder.
    # Which is also the output
    embedding_dim=6,
    # The number of levels to use for finite-scalar quantization.
    levels=[8, 8, 8, 5, 5, 5],
    # The number of quantizers to use for residual finite-scalar quantization.
    num_quantizers=4,
    name="DI",
    # Specify type of encoder ["Default", "LiteVAE"]
    encoder=EncoderType.Default.name,
    # Specify type of decoder ["Default"]
    decoder=DecoderType.Default.name,
)

continuous_video = dict(
    attn_resolutions=[32],
    channels=128,
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    latent_channels=16,
    z_channels=16,
    z_factor=1,
    num_groups=1,
    legacy_mode=False,
    spatial_compression=8,
    temporal_compression=8,
    formulation=ContinuousFormulation.AE.name,
    encoder=Encoder3DType.FACTORIZED.name,
    decoder=Decoder3DType.FACTORIZED.name,
    name="CV",
)

discrete_video = dict(
    attn_resolutions=[32],
    channels=128,
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    z_channels=16,
    z_factor=1,
    num_groups=1,
    legacy_mode=False,
    spatial_compression=16,
    temporal_compression=8,
    quantizer=DiscreteQuantizer.FSQ.name,
    embedding_dim=6,
    levels=[8, 8, 8, 5, 5, 5],
    encoder=Encoder3DType.FACTORIZED.name,
    decoder=Decoder3DType.FACTORIZED.name,
    name="DV",
)
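
As a sanity check on the discrete_video config (a standalone sketch, not part of the commit): the Haar patcher contributes a patch_size-fold compression up front and the factorized encoder covers the remainder, while the FSQ codebook size is the product of the levels:

import math

cfg = dict(patch_size=4, spatial_compression=16, temporal_compression=8, levels=[8, 8, 8, 5, 5, 5])

# Remaining downsampling stages after the wavelet patcher (mirrors the encoder's arithmetic).
encoder_spatial_downs = int(math.log2(cfg["spatial_compression"])) - int(math.log2(cfg["patch_size"]))
encoder_temporal_downs = int(math.log2(cfg["temporal_compression"])) - int(math.log2(cfg["patch_size"]))
print(encoder_spatial_downs, encoder_temporal_downs)  # 2 spatial and 1 temporal stage

codebook_size = math.prod(cfg["levels"])
print(codebook_size)  # 64000 discrete codes per latent position
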